mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
Merge branch 'dev' into update-unit-test
This commit is contained in:
commit
bf0f6bd272
148 changed files with 6195 additions and 691 deletions
|
|
@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
|||
from pydantic import BaseModel, create_model, root_validator, validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseGPTAPI
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
|
|
@ -260,9 +261,10 @@ class ActionNode:
|
|||
output_data_mapping: dict,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=CONFIG.timeout,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs)
|
||||
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
|
||||
logger.debug(f"llm raw output:\n{content}")
|
||||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
|
|
@ -289,13 +291,13 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode):
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode)
|
||||
|
||||
if schema != "raw":
|
||||
mapping = self.get_mapping(mode)
|
||||
class_name = f"{self.key}_AN"
|
||||
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
|
||||
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
|
||||
self.content = content
|
||||
self.instruct_content = scontent
|
||||
else:
|
||||
|
|
@ -304,7 +306,7 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
|
|
@ -320,6 +322,7 @@ class ActionNode:
|
|||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
:param timeout: Timeout for llm invocation.
|
||||
:return: self
|
||||
"""
|
||||
self.set_llm(llm)
|
||||
|
|
@ -328,12 +331,12 @@ class ActionNode:
|
|||
schema = self.schema
|
||||
|
||||
if strgy == "simple":
|
||||
return await self.simple_fill(schema=schema, mode=mode)
|
||||
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
|
||||
elif strgy == "complex":
|
||||
# This implicitly assumes that the node has children
|
||||
tmp = {}
|
||||
for _, i in self.children.items():
|
||||
child = await i.simple_fill(schema=schema, mode=mode)
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
|
||||
tmp.update(child.instruct_content.dict())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -8,6 +7,7 @@ from metagpt.llm import LLM
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.highlight import highlight
|
||||
|
||||
CLONE_PROMPT = """
|
||||
|
|
@ -39,7 +39,7 @@ class CloneFunction(WriteCode):
|
|||
if isinstance(code_path, str):
|
||||
code_path = Path(code_path)
|
||||
code_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
code_path.write_text(code)
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
logger.info(f"Saving Code to {code_path}")
|
||||
|
||||
async def run(self, template_func: str, source_code: str) -> str:
|
||||
|
|
@ -51,20 +51,17 @@ class CloneFunction(WriteCode):
|
|||
return code
|
||||
|
||||
|
||||
@handle_exception
|
||||
def run_function_code(func_code: str, func_name: str, *args, **kwargs):
|
||||
"""Run function code from string code."""
|
||||
try:
|
||||
locals_ = {}
|
||||
exec(func_code, locals_)
|
||||
func = locals_[func_name]
|
||||
return func(*args, **kwargs), ""
|
||||
except Exception:
|
||||
return "", traceback.format_exc()
|
||||
locals_ = {}
|
||||
exec(func_code, locals_)
|
||||
func = locals_[func_name]
|
||||
return func(*args, **kwargs), ""
|
||||
|
||||
|
||||
def run_function_script(code_script_path: str, func_name: str, *args, **kwargs):
|
||||
"""Run function code from script."""
|
||||
if isinstance(code_script_path, str):
|
||||
code_path = Path(code_script_path)
|
||||
code_path = Path(code_script_path)
|
||||
code = code_path.read_text(encoding="utf-8")
|
||||
return run_function_code(code, func_name, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -19,5 +19,5 @@ class ExecuteTask(Action):
|
|||
context: list[Message] = []
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
async def run(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -11,6 +11,3 @@ class FixBug(Action):
|
|||
"""Fix bug action without any implementation details"""
|
||||
|
||||
name: str = "FixBug"
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
68
metagpt/actions/rebuild_class_view.py
Normal file
68
metagpt/actions/rebuild_class_view.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : rebuild_class_view.py
|
||||
@Desc : Rebuild class view info
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class RebuildClassView(Action):
    """Rebuild class-view information for the code under `self.context` and save it as `.mmd` files."""

    def __init__(self, name="", context=None, llm=None):
        """Forward the constructor arguments unchanged to `Action`."""
        super().__init__(name=name, context=context, llm=llm)

    async def run(self, with_messages=None, format=CONFIG.prompt_schema):
        """Refresh the graph database from the source tree and persist the class views.

        :param with_messages: Not used by this implementation.
        :param format: Not used by this implementation.
        """
        workdir = CONFIG.git_repo.workdir
        graph_db_file = (workdir / GRAPH_REPO_FILE_REPO / workdir.name).with_suffix(".json")
        graph_db = await DiGraphRepository.load_from(str(graph_db_file))

        parser = RepoParser(base_directory=self.context)

        # Class views come from pylint.
        views = await parser.rebuild_class_views(path=Path(self.context))
        await GraphRepository.update_graph_db_with_class_views(graph_db, views)

        # File-level symbols come from ast.
        for symbol_info in parser.generate_symbols():
            await GraphRepository.update_graph_db_with_file_info(graph_db, symbol_info)

        await self._create_mermaid_class_view(graph_db=graph_db)
        await self._save(graph_db=graph_db)

    async def _create_mermaid_class_view(self, graph_db):
        """Placeholder: currently does nothing with `graph_db`."""
        pass
        # NOTE: the original author left the following disabled sketch of the intended implementation.
        # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
        # if not dataset:
        #     logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
        #     return
        # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
        # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
        # code_type = ""
        # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
        # for spo in dataset:
        #     if spo.object_ in ["javascript", "python"]:
        #         code_type = spo.object_
        #         break
        #
        # try:
        #     node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
        #     class_view = node.instruct_content.dict()["Class View"]
        # except Exception as e:
        #     class_view = RepoParser.rebuild_class_view(src_code, code_type)
        # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
        # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")

    async def _save(self, graph_db):
        """Write one `.mmd` file per stored class view, plus `all.mmd` joining every view.

        :param graph_db: Graph repository queried for `GraphKeyword.HAS_CLASS_VIEW` entries.
        """
        file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
        triples = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
        merged = []
        for triple in triples:
            header = f"---\ntitle: {triple.subject}\n---\n"
            # "/" and ":" in the subject are replaced so it can serve as a file name.
            safe_name = re.sub(r"[/:]", "_", triple.subject)
            await file_repo.save(filename=safe_name + ".mmd", content=header + triple.object_)
            merged.append(triple.object_)
        await file_repo.save(filename="all.mmd", content="\n".join(merged))
|
||||
33
metagpt/actions/rebuild_class_view_an.py
Normal file
33
metagpt/actions/rebuild_class_view_an.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : rebuild_class_view_an.py
|
||||
@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
|
||||
"""
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
CLASS_SOURCE_CODE_BLOCK = ActionNode(
|
||||
key="Class View",
|
||||
expected_type=str,
|
||||
instruction='Generate the mermaid class diagram corresponding to source code in "context."',
|
||||
example="""
|
||||
classDiagram
|
||||
class A {
|
||||
-int x
|
||||
+int y
|
||||
-int speed
|
||||
-int direction
|
||||
+__init__(x: int, y: int, speed: int, direction: int)
|
||||
+change_direction(new_direction: int) None
|
||||
+move() None
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
REBUILD_CLASS_VIEW_NODES = [
|
||||
CLASS_SOURCE_CODE_BLOCK,
|
||||
]
|
||||
|
||||
REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)
|
||||
|
|
@ -105,6 +105,7 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
"""
|
||||
|
||||
|
||||
# TOTEST
|
||||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
|
|
|
|||
111
metagpt/actions/skill_action.py
Normal file
111
metagpt/actions/skill_action.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/28
|
||||
@Author : mashenquan
|
||||
@File : skill_action.py
|
||||
@Desc : Call learned skill
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.learn.skill_loader import Skill
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
# TOTEST
|
||||
class ArgumentsParingAction(Action):
    """Ask the LLM to turn a natural-language request into keyword arguments for a skill."""

    skill: Skill
    ask: str
    rsp: Optional[Message] = None
    args: Optional[Dict] = None

    @property
    def prompt(self):
        """Build the prompt from the skill's parameter descriptions, its examples and `self.ask`."""
        parts = [
            "You are a function parser. You can convert spoken words into function parameters.\n",
            "\n---\n",
            f"{self.skill.name} function parameters description:\n",
        ]
        parts.extend(f"parameter `{k}`: {v}\n" for k, v in self.skill.arguments.items())
        parts.append("\n---\n")
        parts.append("Examples:\n")
        parts.extend(
            f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n" for e in self.skill.examples
        )
        parts.append("\n---\n")
        parts.append(
            f"\nRefer to the `{self.skill.name}` function description, and fill in the function parameters according "
            'to the example "I want you to do xx" in the Examples section.'
            f"\nNow I want you to do `{self.ask}`, return function parameters in Examples format above, brief and "
            "clear."
        )
        return "".join(parts)

    async def run(self, with_message=None, **kwargs) -> Message:
        """Query the LLM, parse its reply into `self.args`, and return the reply wrapped in a `Message`.

        :param with_message: Not used by this implementation.
        :return: The message stored in `self.rsp`; its `instruct_content` is the parsed arguments
            (``None`` when parsing failed).
        """
        prompt = self.prompt
        rsp = await self.llm.aask(msg=prompt, system_msgs=[])
        logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
        self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
        self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)
        return self.rsp

    @staticmethod
    def parse_arguments(skill_name, txt) -> Optional[dict]:
        """Extract the keyword arguments of a ``skill_name(...)`` call embedded in `txt`.

        :param skill_name: Function name to look for.
        :param txt: Text (normally the LLM reply) containing the call expression.
        :return: A dict of literal keyword arguments, or ``None`` when the call cannot be found
            or its arguments are not valid Python literals.
        """
        prefix = skill_name + "("
        begin_ix = txt.find(prefix)
        if begin_ix < 0:
            logger.error(f"{skill_name} not in {txt}")
            return None
        end_ix = txt.rfind(")")
        if end_ix < 0:
            logger.error(f"')' not in {txt}")
            return None
        args_txt = txt[begin_ix + len(prefix) : end_ix]
        logger.info(args_txt)
        # Wrap the argument text in `dict(...)` so it parses as a single call expression;
        # only the AST is inspected, nothing is executed.
        try:
            call = ast.parse(f"dict({args_txt})", mode="eval").body
            if not isinstance(call, ast.Call):
                # Extra text between the parentheses turned the expression into something else.
                logger.error(f"Invalid arguments of {skill_name}: {args_txt}")
                return None
            # `keyword.arg` is None for `**mapping`; such entries carry no literal key and are skipped.
            return {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords if kw.arg is not None}
        except (SyntaxError, ValueError, TypeError) as e:
            # The text comes from an LLM and may not be a well-formed literal argument list.
            logger.error(f"Invalid arguments of {skill_name}: {args_txt}, error: {e}")
            return None
|
||||
|
||||
|
||||
class SkillAction(Action):
    """Call a learned skill (a coroutine function exported by `metagpt.learn`) with parsed arguments."""

    skill: Skill
    args: Dict
    rsp: Optional[Message] = None

    async def run(self, with_message=None, **kwargs) -> Message:
        """Invoke the skill and wrap its result, or the error text, in a `Message`.

        :param with_message: Not used by this implementation.
        :param kwargs: Extra options for the skill; keys also present in `self.args` are dropped
            so `self.args` takes precedence.
        :return: The message stored in `self.rsp`.
        """
        options = deepcopy(kwargs)
        if self.args:
            for k in self.args:
                options.pop(k, None)
        try:
            rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options)
            self.rsp = Message(content=rsp, role="assistant", cause_by=self)
        except Exception as e:
            # Best effort: report the failure to the caller as message content instead of raising.
            logger.exception(f"{e}, traceback:{traceback.format_exc()}")
            self.rsp = Message(content=f"Error: {e}", role="assistant", cause_by=self)
        return self.rsp

    @staticmethod
    async def find_and_call_function(function_name, args, **kwargs) -> str:
        """Look up `function_name` in `metagpt.learn` and await it.

        :param function_name: Name of the coroutine function to call.
        :param args: Keyword arguments for the function; ``None`` is treated as no arguments.
        :param kwargs: Additional keyword arguments passed through to the function.
        :return: Whatever the awaited function returns.
        :raises ValueError: If the module or the function cannot be found.
        """
        # Only the lookup is guarded, so an AttributeError or ModuleNotFoundError raised
        # inside the skill itself is not misreported as "not found".
        try:
            module = importlib.import_module("metagpt.learn")
            function = getattr(module, function_name)
        except (ModuleNotFoundError, AttributeError) as e:
            logger.error(f"{function_name} not found")
            raise ValueError(f"{function_name} not found") from e
        # Invoke function and return result
        return await function(**(args or {}), **kwargs)
|
||||
|
|
@ -91,6 +91,7 @@ flowchart TB
|
|||
"""
|
||||
|
||||
|
||||
# TOTEST
|
||||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
|
|
|
|||
163
metagpt/actions/talk_action.py
Normal file
163
metagpt/actions/talk_action.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/28
|
||||
@Author : mashenquan
|
||||
@File : talk_action.py
|
||||
@Desc : Act as if it's a talk
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
# TOTEST
|
||||
class TalkAction(Action):
    """Answer `context` as a conversational turn, optionally using knowledge and a history summary."""

    context: str
    history_summary: str = ""
    knowledge: str = ""
    rsp: Optional[Message] = None

    @property
    def prompt(self):
        """Build a single-string prompt from the agent description, knowledge, history and question."""
        prompt = ""
        if CONFIG.agent_description:
            prompt = (
                f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, "
                f"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
                f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
            )
        prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else ""
        prompt += f"{self.history_summary}\n\n"
        prompt += (
            "If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
        )
        language = CONFIG.language or DEFAULT_LANGUAGE
        prompt += (
            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
            f"{self.context}"
        )
        logger.debug(f"PROMPT: {prompt}")
        return prompt

    @property
    def prompt_gpt4(self):
        """Build the prompt by filling the placeholders of `TalkActionPrompt.FORMATION_LOOSE`."""
        kvs = {
            "{role}": CONFIG.agent_description or "",
            "{history}": self.history_summary or "",
            "{knowledge}": self.knowledge or "",
            "{language}": CONFIG.language or DEFAULT_LANGUAGE,
            "{ask}": self.context,
        }
        prompt = TalkActionPrompt.FORMATION_LOOSE
        for k, v in kvs.items():
            prompt = prompt.replace(k, v)
        logger.info(f"PROMPT: {prompt}")
        return prompt

    # async def run_old(self, *args, **kwargs) -> ActionOutput:
    #     prompt = self.prompt
    #     rsp = await self.llm.aask(msg=prompt, system_msgs=[])
    #     logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n")
    #     self._rsp = ActionOutput(content=rsp)
    #     return self._rsp

    @property
    def aask_args(self):
        """Return the ``(msg, format_msgs, system_msgs)`` triple passed to `llm.aask` by `run`.

        The "You are ..." system message is only included when `CONFIG.agent_description` is set,
        so the model is never told "You are None." (the `prompt` property applies the same guard).
        """
        language = CONFIG.language or DEFAULT_LANGUAGE
        system_msgs = []
        if CONFIG.agent_description:
            system_msgs.append(f"You are {CONFIG.agent_description}.")
        system_msgs += [
            "Your responses should align with the role-play agreement, "
            "maintaining the character's persona and habits. When faced with unrelated questions, playfully "
            "decline to answer without revealing your AI nature to preserve the character's image.",
            "If the information is insufficient, you can search in the context or knowledge.",
            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.",
        ]
        # Knowledge and history are supplied as prior assistant messages.
        format_msgs = []
        if self.knowledge:
            format_msgs.append({"role": "assistant", "content": self.knowledge})
        if self.history_summary:
            format_msgs.append({"role": "assistant", "content": self.history_summary})
        return self.context, format_msgs, system_msgs

    async def run(self, with_message=None, **kwargs) -> Message:
        """Ask the LLM and return its answer wrapped in a `Message`.

        :param with_message: Not used by this implementation.
        :return: The message stored in `self.rsp`.
        """
        msg, format_msgs, system_msgs = self.aask_args
        rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
        self.rsp = Message(content=rsp, role="assistant", cause_by=self)
        return self.rsp
|
||||
|
||||
|
||||
class TalkActionPrompt:
|
||||
FORMATION = """Formation: "Capacity and role" defines the role you are currently playing;
|
||||
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
|
||||
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
|
||||
"Statement" defines the work detail you need to complete at this stage;
|
||||
"[ASK_BEGIN]" and [ASK_END] tags enclose the questions;
|
||||
"Constraint" defines the conditions that your responses must comply with.
|
||||
"Personality" defines your language style。
|
||||
"Insight" provides a deeper understanding of the characters' inner traits.
|
||||
"Initial" defines the initial setup of a character.
|
||||
|
||||
Capacity and role: {role}
|
||||
Statement: Your responses should align with the role-play agreement, maintaining the
|
||||
character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing
|
||||
your AI nature to preserve the character's image.
|
||||
|
||||
[HISTORY_BEGIN]
|
||||
|
||||
{history}
|
||||
|
||||
[HISTORY_END]
|
||||
|
||||
[KNOWLEDGE_BEGIN]
|
||||
|
||||
{knowledge}
|
||||
|
||||
[KNOWLEDGE_END]
|
||||
|
||||
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
|
||||
Statement: Unless you are a language professional, answer the following questions strictly in {language}
|
||||
, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
|
||||
, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
|
||||
|
||||
|
||||
{ask}
|
||||
"""
|
||||
|
||||
FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing;
|
||||
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
|
||||
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
|
||||
"Statement" defines the work detail you need to complete at this stage;
|
||||
"Constraint" defines the conditions that your responses must comply with.
|
||||
"Personality" defines your language style。
|
||||
"Insight" provides a deeper understanding of the characters' inner traits.
|
||||
"Initial" defines the initial setup of a character.
|
||||
|
||||
Capacity and role: {role}
|
||||
Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions
|
||||
, playfully decline to answer without revealing your AI nature to preserve the character's image.
|
||||
|
||||
[HISTORY_BEGIN]
|
||||
|
||||
{history}
|
||||
|
||||
[HISTORY_END]
|
||||
|
||||
[KNOWLEDGE_BEGIN]
|
||||
|
||||
{knowledge}
|
||||
|
||||
[KNOWLEDGE_END]
|
||||
|
||||
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
|
||||
Statement: Unless you are a language professional, answer the following questions strictly in {language}
|
||||
, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
|
||||
, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
|
||||
|
||||
|
||||
{ask}
|
||||
"""
|
||||
|
|
@ -123,7 +123,7 @@ class WritePRD(Action):
|
|||
# logger.info(rsp)
|
||||
project_name = CONFIG.project_name if CONFIG.project_name else ""
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm) # schema=schema
|
||||
await self._rename_workspace(node)
|
||||
return node
|
||||
|
||||
|
|
|
|||
193
metagpt/actions/write_teaching_plan.py
Normal file
193
metagpt/actions/write_teaching_plan.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/27
|
||||
@Author : mashenquan
|
||||
@File : write_teaching_plan.py
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
|
||||
|
||||
class WriteTeachingPlanPart(Action):
    """Write one part (`topic`) of a teaching plan for the lesson text held in `context`."""

    context: Optional[str] = None
    llm: BaseGPTAPI = Field(default_factory=LLM)
    topic: str = ""
    language: str = "Chinese"
    rsp: Optional[str] = None

    async def run(self, with_message=None, **kwargs):
        """Build the prompt for `self.topic`, query the LLM and store the cleaned answer in `self.rsp`.

        :param with_message: Not used by this implementation.
        :return: The cleaned response text (`self.rsp`).
        """
        statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
        statements = [self.format_value(p) for p in statement_patterns]
        # The course title uses its own, shorter template.
        formatter = (
            TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
            if self.topic == TeachingPlanBlock.COURSE_TITLE
            else TeachingPlanBlock.PROMPT_TEMPLATE
        )
        prompt = formatter.format(
            formation=TeachingPlanBlock.FORMATION,
            role=self.prefix,
            statements="\n".join(statements),
            lesson=self.context,
            topic=self.topic,
            language=self.language,
        )

        logger.debug(prompt)
        rsp = await self._aask(prompt=prompt)
        logger.debug(rsp)
        self._set_result(rsp)
        return self.rsp

    def _set_result(self, rsp):
        """Strip the begin/end data tags from `rsp` and store the result in `self.rsp`.

        For the course title the result is additionally forced to start with "#".
        """
        begin_tag = TeachingPlanBlock.DATA_BEGIN_TAG
        end_tag = TeachingPlanBlock.DATA_END_TAG
        if begin_tag in rsp:
            # Keep only what follows the first begin tag.
            rsp = rsp.partition(begin_tag)[2]
        if end_tag in rsp:
            # Keep only what precedes the first end tag.
            rsp = rsp.partition(end_tag)[0]
        self.rsp = rsp.strip()
        if self.topic != TeachingPlanBlock.COURSE_TITLE:
            return
        if not self.rsp.startswith("#"):
            self.rsp = "# " + self.rsp

    def __str__(self):
        """Return `topic` value when str()"""
        return self.topic

    def __repr__(self):
        """Show `topic` value when debug"""
        return self.topic

    @staticmethod
    def format_value(value):
        """Fill ``{name}`` placeholders inside `value` from `CONFIG.options`.

        :param value: Template text; non-string values and strings without "{" are returned unchanged.
        :return: The formatted text. When `str.format` fails (a missing key, a positional ``{}``
            field or unbalanced braces), the known ``{key}`` placeholders are replaced one by one
            and everything else is left as it is.
        """
        if not isinstance(value, str):
            return value
        if "{" not in value:
            return value

        merged_opts = CONFIG.options or {}
        try:
            return value.format(**merged_opts)
        except KeyError as e:
            logger.warning(f"Parameter is missing:{e}")
        except (IndexError, ValueError) as e:
            # e.g. a bare "{}" field or an unmatched brace in the template text.
            logger.warning(f"Invalid format string:{e}")

        for k, v in merged_opts.items():
            value = value.replace("{" + f"{k}" + "}", str(v))
        return value
|
||||
|
||||
|
||||
class TeachingPlanBlock:
    """Constants used by `WriteTeachingPlanPart`: topic names, per-topic statements and prompt templates."""

    # Explanation of the section labels used in the prompts; substituted for `{formation}`.
    FORMATION = (
        '"Capacity and role" defines the role you are currently playing;\n'
        '\t"[LESSON_BEGIN]" and "[LESSON_END]" tags enclose the content of textbook;\n'
        '\t"Statement" defines the work detail you need to complete at this stage;\n'
        '\t"Answer options" defines the format requirements for your responses;\n'
        '\t"Constraint" defines the conditions that your responses must comply with.'
    )

    # Topic name that selects PROMPT_TITLE_TEMPLATE instead of PROMPT_TEMPLATE.
    COURSE_TITLE = "Title"
    # Names of the teaching plan parts.
    TOPICS = [
        COURSE_TITLE,
        "Teaching Hours",
        "Teaching Objectives",
        "Teaching Content",
        "Teaching Methods and Strategies",
        "Learning Activities",
        "Teaching Time Allocation",
        "Assessment and Feedback",
        "Teaching Summary and Improvement",
        "Vocabulary Cloze",
        "Choice Questions",
        "Grammar Questions",
        "Translation Questions",
    ]

    # Extra "Statement:" lines per topic; topics without an entry get none.
    # `{language}` / `{teaching_language}` are placeholders filled by `WriteTeachingPlanPart.format_value`.
    TOPIC_STATEMENTS = {
        COURSE_TITLE: [
            "Statement: Find and return the title of the lesson only in markdown first-level header format, "
            "without anything else."
        ],
        "Teaching Content": [
            'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar '
            "structures that appear in the textbook, as well as the listening materials and key points.",
            'Statement: "Teaching Content" must include more examples.',
        ],
        "Teaching Time Allocation": [
            'Statement: "Teaching Time Allocation" must include how much time is allocated to each '
            "part of the textbook content."
        ],
        "Teaching Methods and Strategies": [
            'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, '
            "procedures, in detail."
        ],
        "Vocabulary Cloze": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} "
            "answers, and it should also include 10 {teaching_language} questions with {language} answers. "
            "The key-related vocabulary and phrases in the textbook content must all be included in the exercises.",
        ],
        "Grammar Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create grammar questions. 10 questions."
        ],
        "Choice Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create choice questions. 10 questions."
        ],
        "Translation Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create translation questions. The translation should include 10 {language} questions with "
            "{teaching_language} answers, and it should also include 10 {teaching_language} questions with "
            "{language} answers."
        ],
    }

    # Teaching plan title
    PROMPT_TITLE_TEMPLATE = (
        "Do not refer to the context of the previous conversation records, "
        "start the conversation anew.\n\n"
        "Formation: {formation}\n\n"
        "{statements}\n"
        "Constraint: Writing in {language}.\n"
        'Answer options: Encloses the lesson title with "[TEACHING_PLAN_BEGIN]" '
        'and "[TEACHING_PLAN_END]" tags.\n'
        "[LESSON_BEGIN]\n"
        "{lesson}\n"
        "[LESSON_END]"
    )

    # Teaching plan parts:
    PROMPT_TEMPLATE = (
        "Do not refer to the context of the previous conversation records, "
        "start the conversation anew.\n\n"
        "Formation: {formation}\n\n"
        "Capacity and role: {role}\n"
        'Statement: Write the "{topic}" part of teaching plan, '
        'WITHOUT ANY content unrelated to "{topic}"!!\n'
        "{statements}\n"
        'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" '
        'and "[TEACHING_PLAN_END]" tags.\n'
        "Answer options: Using proper markdown format from second-level header format.\n"
        "Constraint: Writing in {language}.\n"
        "[LESSON_BEGIN]\n"
        "{lesson}\n"
        "[LESSON_END]"
    )

    # Tags the LLM is asked to wrap its answer in; `WriteTeachingPlanPart._set_result` strips them.
    DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]"
    DATA_END_TAG = "[TEACHING_PLAN_END]"
|
||||
|
|
@ -44,7 +44,7 @@ you should correctly import the necessary classes based on these file locations!
|
|||
|
||||
class WriteTest(Action):
|
||||
name: str = "WriteTest"
|
||||
context: Optional[str] = None
|
||||
context: Optional[TestingContext] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def write_code(self, prompt):
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@ Provide configuration, singleton
|
|||
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
|
||||
2. Add the parameter `src_workspace` for the old version project path.
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
|
|
@ -19,6 +22,7 @@ from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
|
|||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.common import require_python_version
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
|
|
@ -42,6 +46,8 @@ class LLMProviderEnum(Enum):
|
|||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
METAGPT = "metagpt"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
|
|
@ -58,7 +64,7 @@ class Config(metaclass=Singleton):
|
|||
key_yaml_file = METAGPT_ROOT / "config/key.yaml"
|
||||
default_yaml_file = METAGPT_ROOT / "config/config.yaml"
|
||||
|
||||
def __init__(self, yaml_file=default_yaml_file):
|
||||
def __init__(self, yaml_file=default_yaml_file, cost_data=""):
|
||||
global_options = OPTIONS.get()
|
||||
# cli paras
|
||||
self.project_path = ""
|
||||
|
|
@ -68,32 +74,57 @@ class Config(metaclass=Singleton):
|
|||
self.max_auto_summarize_code = 0
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
|
||||
self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
|
||||
self._update()
|
||||
global_options.update(OPTIONS.get())
|
||||
logger.debug("Config loading done.")
|
||||
|
||||
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
|
||||
for k, v in [
|
||||
(self.openai_api_key, LLMProviderEnum.OPENAI),
|
||||
(self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
|
||||
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
|
||||
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
|
||||
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
|
||||
(self.gemini_api_key, LLMProviderEnum.GEMINI),
|
||||
(self.ollama_api_base, LLMProviderEnum.OLLAMA), # reuse logic. but not a key
|
||||
]:
|
||||
if self._is_valid_llm_key(k):
|
||||
# logger.debug(f"Use LLMProvider: {v.value}")
|
||||
if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
if self.openai_api_key and self.openai_api_model:
|
||||
logger.info(f"OpenAI API Model: {self.openai_api_model}")
|
||||
return v
|
||||
"""Get first valid LLM provider enum"""
|
||||
mappings = {
|
||||
LLMProviderEnum.OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
|
||||
),
|
||||
LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
|
||||
LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
|
||||
LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
|
||||
LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
|
||||
LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
|
||||
LLMProviderEnum.METAGPT: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
|
||||
),
|
||||
LLMProviderEnum.AZURE_OPENAI: bool(
|
||||
self._is_valid_llm_key(self.OPENAI_API_KEY)
|
||||
and self.OPENAI_API_TYPE == "azure"
|
||||
and self.DEPLOYMENT_NAME
|
||||
and self.OPENAI_API_VERSION
|
||||
),
|
||||
LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
|
||||
}
|
||||
provider = None
|
||||
for k, v in mappings.items():
|
||||
if v:
|
||||
provider = k
|
||||
break
|
||||
|
||||
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
model_mappings = {
|
||||
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
|
||||
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
|
||||
}
|
||||
model_name = model_mappings.get(provider)
|
||||
if model_name:
|
||||
logger.info(f"{provider} Model: {model_name}")
|
||||
if provider:
|
||||
logger.info(f"API: {provider}")
|
||||
return provider
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_llm_key(k: str) -> bool:
|
||||
return k and k != "YOUR_API_KEY"
|
||||
return bool(k and k != "YOUR_API_KEY")
|
||||
|
||||
def _update(self):
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
|
|
@ -111,7 +142,7 @@ class Config(metaclass=Singleton):
|
|||
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
# self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
|
|
@ -142,8 +173,7 @@ class Config(metaclass=Singleton):
|
|||
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.code_review_k_times = 2
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
|
|
@ -154,10 +184,18 @@ class Config(metaclass=Singleton):
|
|||
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
|
||||
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
|
||||
|
||||
workspace_uid = (
|
||||
self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
|
||||
)
|
||||
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
|
||||
self.prompt_schema = self._get("PROMPT_FORMAT", "json")
|
||||
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
|
||||
val = self._get("WORKSPACE_PATH_WITH_UID")
|
||||
if val and val.lower() == "true": # for agent
|
||||
self.workspace_path = self.workspace_path / workspace_uid
|
||||
self._ensure_workspace_exists()
|
||||
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
|
||||
self.timeout = int(self._get("TIMEOUT", 3))
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
|
@ -198,7 +236,8 @@ class Config(metaclass=Singleton):
|
|||
return i.get(*args, **kwargs)
|
||||
|
||||
def get(self, key, *args, **kwargs):
|
||||
"""Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found"""
|
||||
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
|
||||
Throw an error if not found."""
|
||||
value = self._get(key, *args, **kwargs)
|
||||
if value is None:
|
||||
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
|
||||
|
|
|
|||
|
|
@ -48,9 +48,10 @@ def get_metagpt_root():
|
|||
|
||||
# METAGPT PROJECT ROOT AND VARS
|
||||
|
||||
METAGPT_ROOT = get_metagpt_root()
|
||||
METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
|
||||
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
||||
|
||||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
|
||||
|
|
@ -100,7 +101,27 @@ TEST_CODES_FILE_REPO = "tests"
|
|||
TEST_OUTPUTS_FILE_REPO = "test_outputs"
|
||||
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
|
||||
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
|
||||
RESOURCES_FILE_REPO = "resources"
|
||||
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
|
||||
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
|
||||
CLASS_VIEW_FILE_REPO = "docs/class_views"
|
||||
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
|
||||
DEFAULT_LANGUAGE = "English"
|
||||
DEFAULT_MAX_TOKENS = 1500
|
||||
COMMAND_TOKENS = 500
|
||||
BRAIN_MEMORY = "BRAIN_MEMORY"
|
||||
SKILL_PATH = "SKILL_PATH"
|
||||
SERPER_API_KEY = "SERPER_API_KEY"
|
||||
DEFAULT_TOKEN_SIZE = 500
|
||||
|
||||
# format
|
||||
BASE64_FORMAT = "base64"
|
||||
|
||||
# REDIS
|
||||
REDIS_KEY = "REDIS_KEY"
|
||||
LLM_API_TIMEOUT = 300
|
||||
|
||||
# Message id
|
||||
IGNORED_MESSAGE_ID = "0"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from langchain.embeddings import OpenAIEmbeddings
|
|||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
|
|
@ -25,7 +26,9 @@ class FaissStore(LocalStore):
|
|||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.embedding = embedding or OpenAIEmbeddings(
|
||||
openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
|
||||
)
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from typing import Iterable, Set
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role, role_subclass_registry
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -108,7 +109,7 @@ class Environment(BaseModel):
|
|||
for role in roles: # setup system message with roles
|
||||
role.set_env(self)
|
||||
|
||||
def publish_message(self, message: Message) -> bool:
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
"""
|
||||
Distribute the message to the recipients.
|
||||
In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
|
||||
|
|
@ -173,3 +174,8 @@ class Environment(BaseModel):
|
|||
def set_subscription(self, obj, tags):
|
||||
"""Set the labels for message to be consumed by the object"""
|
||||
self.members[obj] = tags
|
||||
|
||||
@staticmethod
def archive(auto_archive=True):
    """Archive the configured git repository.

    Nothing happens when ``auto_archive`` is falsy or ``CONFIG.git_repo`` is not set.
    """
    if not auto_archive or not CONFIG.git_repo:
        return
    CONFIG.git_repo.archive()
|
||||
|
|
|
|||
|
|
@ -5,3 +5,9 @@
|
|||
@Author : alexanderwu
|
||||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
from metagpt.learn.google_search import google_search
|
||||
|
||||
__all__ = ["text_to_image", "text_to_speech", "google_search"]
|
||||
|
|
|
|||
12
metagpt/learn/google_search.py
Normal file
12
metagpt/learn/google_search.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
async def google_search(query: str, max_results: int = 6, **kwargs):
    """Run a web search and format the hits as a numbered markdown list.

    :param query: The search query.
    :param max_results: The number of search results to retrieve.
    :return: One line per hit, ``"<rank>. [<title>](<link>): <snippet>"``, joined with newlines.
    """
    hits = await SearchEngine().run(query, max_results=max_results, as_string=False)
    lines = []
    for rank, hit in enumerate(hits, 1):  # ranks are 1-based
        lines.append(f"{rank}. [{hit['title']}]({hit['link']}): {hit['snippet']}")
    return "\n".join(lines)
|
||||
100
metagpt/learn/skill_loader.py
Normal file
100
metagpt/learn/skill_loader.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : skill_loader.py
|
||||
@Desc : Skill YAML Configuration Loader.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiofiles
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class Example(BaseModel):
    # One ask/answer pair listed under a skill's `examples`.
    ask: str
    answer: str
|
||||
|
||||
|
||||
class Returns(BaseModel):
    # Declares what a skill returns: a mandatory `type` and an optional `format`.
    type: str
    format: Optional[str] = None
|
||||
|
||||
|
||||
class Parameter(BaseModel):
    """Declaration of a single skill parameter."""

    type: str
    # Annotated as Optional because the default is None (matches `Returns.format`).
    description: Optional[str] = None
|
||||
|
||||
|
||||
class Skill(BaseModel):
    """One skill entry of an entity in the skills declaration."""

    name: str
    # Fields defaulting to None are annotated Optional explicitly.
    description: Optional[str] = None
    id: Optional[str] = None
    # Populated from the "x-prerequisite" key via the field alias.
    x_prerequisite: Optional[Dict] = Field(default=None, alias="x-prerequisite")
    parameters: Optional[Dict[str, Parameter]] = None
    examples: List[Example]
    returns: Returns

    @property
    def arguments(self) -> Dict:
        """Map each parameter name to its description.

        :return: ``{}`` when the skill declares no parameters; a parameter
            without a description maps to ``""``.
        """
        if not self.parameters:
            return {}
        return {name: param.description or "" for name, param in self.parameters.items()}
|
||||
|
||||
|
||||
class Entity(BaseModel):
    """A named group of skills inside the skills declaration."""

    # Annotated as Optional because the default is None.
    name: Optional[str] = None
    skills: List[Skill]
|
||||
|
||||
|
||||
class Components(BaseModel):
    # Placeholder model: it declares no fields of its own.
    pass
|
||||
|
||||
|
||||
class SkillsDeclaration(BaseModel):
    """Root object of a skills YAML file."""

    skillapi: str
    entities: Dict[str, Entity]
    # Annotated as Optional because the default is None.
    components: Optional[Components] = None

    @staticmethod
    async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
        """Read a skills YAML file and build the declaration from it.

        :param skill_yaml_file_name: Path of the YAML file. When falsy,
            ``.well-known/skills.yaml`` three directory levels above this module is used.
        :return: The parsed declaration.
        """
        if not skill_yaml_file_name:
            skill_yaml_file_name = Path(__file__).parent.parent.parent / ".well-known/skills.yaml"
        # Read as UTF-8 explicitly so the result does not depend on the platform's default locale encoding.
        async with aiofiles.open(str(skill_yaml_file_name), mode="r", encoding="utf-8") as reader:
            data = await reader.read(-1)
        skill_data = yaml.safe_load(data)
        return SkillsDeclaration(**skill_data)

    def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
        """Return ``{description: name}`` for the entity's skills that are enabled in ``CONFIG.agent_skills``.

        :param entity_name: Key of the entity to look up in ``entities``.
        :return: ``{}`` when the entity is unknown or no agent skills are configured.
        """
        entity = self.entities.get(entity_name)
        if not entity:
            return {}

        # List of skills that the agent chooses to activate.
        agent_skills = CONFIG.agent_skills
        if not agent_skills:
            return {}

        class _AgentSkill(BaseModel):
            name: str

        # Each configured entry is validated through `_AgentSkill`; a set gives O(1) membership tests.
        names = {_AgentSkill(**i).name for i in agent_skills}
        return {s.description: s.name for s in entity.skills if s.name in names}

    def get_skill(self, name, entity_name: str = "Assistant") -> Optional[Skill]:
        """Return the entity's first skill called ``name``, or ``None`` when the entity or skill is unknown."""
        entity = self.entities.get(entity_name)
        if not entity:
            return None
        return next((sk for sk in entity.skills if sk.name == name), None)
|
||||
24
metagpt/learn/text_to_embedding.py
Normal file
24
metagpt/learn/text_to_embedding.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : text_to_embedding.py
|
||||
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
|
||||
"""
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
|
||||
|
||||
|
||||
async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
    """Text to embedding

    :param text: The text used for embedding.
    :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
    :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
    :return: The result of `oas3_openai_text_to_embedding` (originally documented as a json object of
        :class:`ResultEmbedding` if successful, otherwise `{}`).
    :raises EnvironmentError: If neither `CONFIG.OPENAI_API_KEY` nor `openai_api_key` is set.
    """
    if CONFIG.OPENAI_API_KEY or openai_api_key:
        return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
    # Same exception type as before, now with a message that says what is missing.
    raise EnvironmentError("OPENAI_API_KEY is not configured and no `openai_api_key` argument was provided")
|
||||
38
metagpt/learn/text_to_image.py
Normal file
38
metagpt/learn/text_to_image.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : text_to_image.py
|
||||
@Desc : Text-to-Image skill, which provides text-to-image functionality.
|
||||
"""
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
    """Text to image

    :param text: The text used for image conversion.
    :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
    :param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
    :param model_url: MetaGPT model url
    :return: A markdown image link when the image could be cached to S3; otherwise the image as a
        Base64 ``data:image/png`` URI, or ``""`` when no image data was produced.
    :raises ValueError: If neither a MetaGPT model url nor an OpenAI API key is available.
    """
    image_declaration = "data:image/png;base64,"
    # The MetaGPT backend takes precedence over OpenAI when both are configured.
    if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
        base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
    elif CONFIG.OPENAI_API_KEY or openai_api_key:
        base64_data = await oas3_openai_text_to_image(text, size_type)
    else:
        raise ValueError("Missing necessary parameters.")

    s3 = S3()
    url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
    if url:
        # Previously an empty string was returned here, which threw away the cached URL.
        return f"![{text}]({url})"
    return image_declaration + base64_data if base64_data else ""
|
||||
72
metagpt/learn/text_to_speech.py
Normal file
72
metagpt/learn/text_to_speech.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : text_to_speech.py
|
||||
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
|
||||
"""
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.azure_tts import oas3_azsure_tts
|
||||
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def _publish_audio(text, base64_data, file_ext, audio_declaration):
    """Cache the audio to S3 and return a markdown link, falling back to an inline data URI."""
    s3 = S3()
    url = await s3.cache(data=base64_data, file_ext=file_ext, format=BASE64_FORMAT) if s3.is_valid else ""
    if url:
        return f"[{text}]({url})"
    if base64_data:
        return audio_declaration + base64_data
    return base64_data


async def text_to_speech(
    text,
    lang="zh-CN",
    voice="zh-CN-XiaomoNeural",
    style="affectionate",
    role="Girl",
    subscription_key="",
    region="",
    iflytek_app_id="",
    iflytek_api_key="",
    iflytek_api_secret="",
    **kwargs,
):
    """Convert text to speech with Azure TTS or, failing that, iFlyTek TTS.

    Azure language/voice/style/role reference:
    `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`

    :param text: The text used for voice conversion.
    :param lang: A language code such as en (English), or a locale such as en-US (English - United States).
    :param voice: Voice name; see also `https://speech.microsoft.com/portal/voicegallery`
    :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm.
    :param role: With roles, the same voice can act as a different age and gender.
    :param subscription_key: Key used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
    :param region: The location (or region) of your Azure resource.
    :param iflytek_app_id: Application ID used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
    :param iflytek_api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :param iflytek_api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :return: A markdown link to the cached audio when S3 caching succeeds; otherwise the Base64-encoded
        .wav/.mp3 data prefixed with its data-URI declaration, or the empty result of the TTS call.
    :raises openai.InvalidRequestError: If neither the Azure nor the iFlyTek credentials are available.
    """
    # Azure is tried first; credentials may come from CONFIG or from the arguments.
    azure_ready = (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region)
    if azure_ready:
        base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
        return await _publish_audio(text, base64_data, ".wav", "data:audio/wav;base64,")

    iflytek_ready = (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or (
        iflytek_app_id and iflytek_api_key and iflytek_api_secret
    )
    if iflytek_ready:
        base64_data = await oas3_iflytek_tts(
            text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
        )
        return await _publish_audio(text, base64_data, ".mp3", "data:audio/mp3;base64,")

    raise openai.InvalidRequestError(
        message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error",
        param={},
    )
|
||||
|
|
@ -4,11 +4,11 @@
|
|||
@Time : 2023/6/5 01:44
|
||||
@Author : alexanderwu
|
||||
@File : skill_manager.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove useless `_llm`
|
||||
"""
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import PROMPT_PATH
|
||||
from metagpt.document_store.chromadb_store import ChromaStore
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
Skill = Action
|
||||
|
|
@ -18,7 +18,6 @@ class SkillManager:
|
|||
"""Used to manage all skills"""
|
||||
|
||||
def __init__(self):
|
||||
self._llm = LLM()
|
||||
self._store = ChromaStore("skill_manager")
|
||||
self._skills: dict[str:Skill] = {}
|
||||
|
||||
|
|
|
|||
253
metagpt/memory/brain_memory.py
Normal file
253
metagpt/memory/brain_memory.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : brain_memory.py
|
||||
@Desc : Used by AgentStore. Used for long-term storage and automatic compression.
|
||||
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
|
||||
@Modified By: mashenquan, 2023/12/25. Simplify Functionality.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTAPI
|
||||
from metagpt.schema import Message, SimpleMessage
|
||||
from metagpt.utils.redis import Redis
|
||||
|
||||
|
||||
class BrainMemory(BaseModel):
    """Conversation memory: message history, knowledge and a rolling summary, cacheable in Redis.

    ``is_dirty`` records whether the in-memory state has changed since it was last loaded or dumped.
    """

    history: List[Message] = Field(default_factory=list)
    knowledge: List[Message] = Field(default_factory=list)
    historical_summary: str = ""
    last_history_id: str = ""
    is_dirty: bool = False
    last_talk: str = None
    cacheable: bool = True

    @staticmethod
    def _as_message(m) -> Message:
        """Return a history entry as a ``Message``; entries stored in dict form are converted."""
        return m if isinstance(m, Message) else Message(**m)

    def add_talk(self, msg: Message):
        """Add message from user."""
        msg.role = "user"
        self.add_history(msg)
        self.is_dirty = True

    def add_answer(self, msg: Message):
        """Add message from LLM"""
        msg.role = "assistant"
        self.add_history(msg)
        self.is_dirty = True

    def get_knowledge(self) -> str:
        """Return the contents of all knowledge messages, one per line."""
        return "\n".join(m.content for m in self.knowledge)

    @staticmethod
    async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
        """Load a memory from Redis.

        :return: The stored memory with ``is_dirty`` cleared, or a fresh empty memory when Redis is
            not valid, ``redis_key`` is empty, or nothing is stored under the key.
        """
        redis = Redis(conf=redis_conf)
        if not redis.is_valid() or not redis_key:
            return BrainMemory()
        v = await redis.get(key=redis_key)
        logger.debug(f"REDIS GET {redis_key} {v}")
        if v:
            bm = BrainMemory.parse_raw(v)
            bm.is_dirty = False
            return bm
        return BrainMemory()

    async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
        """Write the memory to Redis if it has unsaved changes.

        :return: ``False`` when Redis is not valid or ``redis_key`` is empty, otherwise ``None``.
        """
        if not self.is_dirty:
            return
        redis = Redis(conf=redis_conf)
        if not redis.is_valid() or not redis_key:
            return False
        if self.cacheable:
            # Serialize only when the value is actually going to be written.
            v = self.json(ensure_ascii=False)
            await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
            logger.debug(f"REDIS SET {redis_key} {v}")
        self.is_dirty = False

    @staticmethod
    def to_redis_key(prefix: str, user_id: str, chat_id: str):
        """Build the Redis key ``"<prefix>:<user_id>:<chat_id>"``."""
        return f"{prefix}:{user_id}:{chat_id}"

    async def set_history_summary(self, history_summary, redis_key, redis_conf):
        """Replace the summary with ``history_summary``, clear the history and persist.

        When the summary is unchanged, the history is kept and only pending changes are persisted.
        """
        if self.historical_summary == history_summary:
            if self.is_dirty:
                await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
                self.is_dirty = False
            return

        self.historical_summary = history_summary
        self.history = []
        await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
        self.is_dirty = False

    def add_history(self, msg: Message):
        """Append ``msg`` to the history unless its id is not newer than the last recorded id."""
        if msg.id:
            if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
                return
        # Store the Message itself (the field is List[Message]); the readers below access
        # `.content` / `.role`, which a dict snapshot does not provide.
        self.history.append(msg)
        self.last_history_id = str(msg.id)
        self.is_dirty = True

    def exists(self, text) -> bool:
        """Return True when some history entry's content equals ``text`` (newest entries checked first)."""
        return any(self._as_message(m).content == text for m in reversed(self.history))

    @staticmethod
    def to_int(v, default_value):
        """Convert ``v`` with ``int``; return ``default_value`` when it cannot be converted."""
        try:
            return int(v)
        except (TypeError, ValueError, OverflowError):
            return default_value

    def pop_last_talk(self):
        """Return ``last_talk`` and reset it to ``None``."""
        v = self.last_talk
        self.last_talk = None
        return v

    async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
        """Summarize the memory, dispatching on whether ``llm`` is a ``MetaGPTAPI`` instance."""
        if isinstance(llm, MetaGPTAPI):
            return await self._metagpt_summarize(max_words=max_words)

        return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)

    async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
        """Summarize the previous summary plus the history through ``llm.summarize``.

        :return: The joined text unchanged when ``limit`` is positive and the text is shorter than it,
            otherwise the summary produced by the LLM.
        :raises ValueError: If the LLM returns an empty summary.
        """
        texts = [self.historical_summary]
        texts.extend(self._as_message(m).content for m in self.history)
        text = "\n".join(texts)

        text_length = len(text)
        if limit > 0 and text_length < limit:
            return text
        summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
        if summary:
            # NOTE(review): this path reads CONFIG.REDIS while `_metagpt_summarize` reads
            # CONFIG.REDIS_CONF — confirm which configuration key is intended.
            await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
            return summary
        raise ValueError(f"text too long:{text_length}")

    async def _metagpt_summarize(self, max_words=200):
        """Trim the history to the most recent messages and return it in MetaGPT history format.

        The budget is compared against ``len(content)``, i.e. it is counted in characters. The message
        that crosses the budget is truncated in place; older messages are dropped.
        """
        if not self.history:
            return ""

        total_length = 0
        msgs = []
        for entry in reversed(self.history):
            m = self._as_message(entry)
            delta = len(m.content)
            if total_length + delta > max_words:
                left = max_words - total_length
                if left == 0:
                    break
                m.content = m.content[0:left]
                # Keep the element type uniform (Message) so the formatter below can read attributes.
                msgs.append(m)
                break
            msgs.append(m)
            total_length += delta
        msgs.reverse()
        self.history = msgs
        self.is_dirty = True
        await self.dumps(redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS_CONF)
        self.is_dirty = False

        return BrainMemory.to_metagpt_history_format(self.history)

    @staticmethod
    def to_metagpt_history_format(history) -> str:
        """Serialize ``history`` to a JSON array of ``{role, content}`` objects."""
        # json.dumps cannot encode model instances, so each SimpleMessage is converted to a dict first.
        mmsg = [
            SimpleMessage(role=m.role, content=m.content).dict() for m in map(BrainMemory._as_message, history)
        ]
        return json.dumps(mmsg)

    async def get_title(self, llm, max_words=5, **kwargs) -> str:
        """Generate text title"""
        if isinstance(llm, MetaGPTAPI):
            return self._as_message(self.history[0]).content if self.history else "New"

        summary = await self.summarize(llm=llm, max_words=500)

        language = CONFIG.language or DEFAULT_LANGUAGE
        command = f"Translate the above summary into a {language} title of less than {max_words} words."
        summaries = [summary, command]
        msg = "\n".join(summaries)
        logger.debug(f"title ask:{msg}")
        response = await llm.aask(msg=msg, system_msgs=[])
        logger.debug(f"title rsp: {response}")
        return response

    async def is_related(self, text1, text2, llm):
        """Ask whether ``text2`` contains a sentence related to ``text1``, dispatching on the LLM type."""
        if isinstance(llm, MetaGPTAPI):
            return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
        return await self._openai_is_related(text1=text1, text2=text2, llm=llm)

    @staticmethod
    async def _metagpt_is_related(**kwargs):
        """Always report "not related" for the MetaGPT backend."""
        return False

    @staticmethod
    async def _openai_is_related(text1, text2, llm, **kwargs):
        """Ask the LLM for a [TRUE]/[FALSE] verdict; True when the reply contains "TRUE"."""
        command = (
            f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
            "any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
        )
        rsp = await llm.aask(msg=command, system_msgs=[])
        result = "TRUE" in rsp
        p2 = text2.replace("\n", "")
        p1 = text1.replace("\n", "")
        logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
        return result

    async def rewrite(self, sentence: str, context: str, llm):
        """Rewrite ``sentence`` using ``context``, dispatching on the LLM type."""
        if isinstance(llm, MetaGPTAPI):
            return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
        return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)

    @staticmethod
    async def _metagpt_rewrite(sentence: str, **kwargs):
        """Return ``sentence`` unchanged; extra keyword arguments passed by ``rewrite`` are ignored."""
        return sentence

    @staticmethod
    async def _openai_rewrite(sentence: str, context: str, llm):
        """Ask the LLM to supplement or rewrite ``sentence`` with information taken from ``context``."""
        command = (
            f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
            f"supplement or rewrite the following text in brief and clear:\n{sentence}"
        )
        rsp = await llm.aask(msg=command, system_msgs=[])
        logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
        return rsp

    @staticmethod
    def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
        """Split a ``"[TAG]: text"`` string into ``(tag, text)``; return ``(None, input_string)`` on no match."""
        match = re.match(pattern, input_string)
        if match:
            return match.group(1), match.group(2)
        return None, input_string

    @property
    def is_history_available(self):
        """True when there is any history or a historical summary."""
        return bool(self.history or self.historical_summary)

    @property
    def history_text(self):
        """The summary followed by the content of every history entry except the most recent one."""
        if len(self.history) == 0 and not self.historical_summary:
            return ""
        texts = [self.historical_summary] if self.historical_summary else []
        for m in self.history[:-1]:
            if isinstance(m, dict):
                t = Message(**m).content
            elif isinstance(m, Message):
                t = m.content
            else:
                # Entries of any other type are skipped.
                continue
            texts.append(t)

        return "\n".join(texts)
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Desc : the implement of Long-term memory
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Iterable, Set
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import (
|
||||
any_to_str,
|
||||
|
|
@ -26,6 +27,7 @@ class Memory(BaseModel):
|
|||
|
||||
storage: list[Message] = []
|
||||
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
index = kwargs.get("index", {})
|
||||
|
|
@ -54,6 +56,8 @@ class Memory(BaseModel):
|
|||
|
||||
def add(self, message: Message):
|
||||
"""Add a new message to storage, while updating the index"""
|
||||
if self.ignore_id:
|
||||
message.id = IGNORED_MESSAGE_ID
|
||||
if message in self.storage:
|
||||
return
|
||||
self.storage.append(message)
|
||||
|
|
@ -84,6 +88,8 @@ class Memory(BaseModel):
|
|||
|
||||
def delete(self, message: Message):
|
||||
"""Delete the specified message from storage, while updating the index"""
|
||||
if self.ignore_id:
|
||||
message.id = IGNORED_MESSAGE_ID
|
||||
self.storage.remove(message)
|
||||
if message.cause_by and message in self.index[message.cause_by]:
|
||||
self.index[message.cause_by].remove(message)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the implement of memory storage
|
||||
"""
|
||||
@Desc : the implement of memory storage
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
|
|
@ -19,20 +24,30 @@ class MemoryStorage(FaissStore):
|
|||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
||||
def __init__(self, mem_ttl: int = MEM_TTL):
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
|
||||
self.role_id: str = None
|
||||
self.role_mem_path: str = None
|
||||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
    """Report the storage's ``_initialized`` flag."""
    ready = self._initialized
    return ready
|
||||
|
||||
def recover_memory(self, role_id: str) -> List[Message]:
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
|
||||
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
|
||||
|
||||
def recover_memory(self, role_id: str) -> list[Message]:
|
||||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
|
||||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -49,16 +64,16 @@ class MemoryStorage(FaissStore):
|
|||
|
||||
return messages
|
||||
|
||||
def _get_index_and_store_fname(self):
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
if not self.role_mem_path:
|
||||
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
|
||||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
super().persist()
|
||||
self.store.save_local(self.role_mem_path, self.role_id)
|
||||
logger.debug(f"Agent {self.role_id} persist memory into local")
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
|
|
@ -74,7 +89,7 @@ class MemoryStorage(FaissStore):
|
|||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
|
||||
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
|
||||
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
|
||||
"""search for dissimilar messages"""
|
||||
if not self.store:
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -12,5 +12,16 @@ from metagpt.provider.ollama_api import OllamaGPTAPI
|
|||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]
|
||||
__all__ = [
|
||||
"FireWorksGPTAPI",
|
||||
"GeminiGPTAPI",
|
||||
"OpenLLMGPTAPI",
|
||||
"OpenAIGPTAPI",
|
||||
"ZhiPuAIGPTAPI",
|
||||
"AzureOpenAIGPTAPI",
|
||||
"MetaGPTAPI",
|
||||
"OllamaGPTAPI",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@
|
|||
"""
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class Claude2:
|
||||
def ask(self, prompt):
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
|
|
@ -23,10 +23,10 @@ class Claude2:
|
|||
)
|
||||
return res.completion
|
||||
|
||||
async def aask(self, prompt):
|
||||
client = Anthropic(api_key=CONFIG.anthropic_api_key)
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
|
||||
|
||||
res = client.completions.create(
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
|
|
|
|||
69
metagpt/provider/azure_openai_api.py
Normal file
69
metagpt/provider/azure_openai_api.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : azure_openai_api.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.AZURE_OPENAI)
class AzureOpenAIGPTAPI(OpenAIGPTAPI):
    """OpenAI-style provider that talks to Azure OpenAI through the Azure clients.

    Check https://platform.openai.com/examples for examples
    """

    def __init__(self):
        """Bind the global config, run the OpenAI initialization and set up rate limiting."""
        self.config: Config = CONFIG
        self._init_openai()
        self.auto_max_tokens = False
        # NOTE(review): `self.rpm` is presumably set by `_init_openai`; confirm.
        RateLimiter.__init__(self, rpm=self.rpm)

    def _make_client(self):
        """Create the sync and async Azure clients and take the model name from the deployment."""
        sync_kwargs, async_kwargs = self._make_client_kwargs()
        # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
        self.client = AzureOpenAI(**sync_kwargs)
        self.async_client = AsyncAzureOpenAI(**async_kwargs)
        # Used in _calc_usage & _cons_kwargs
        self.model = self.config.DEPLOYMENT_NAME

    def _make_client_kwargs(self) -> (dict, dict):
        """Return ``(kwargs, async_kwargs)`` for the sync and async client constructors.

        Both carry the API key, API version and Azure endpoint from the config;
        when proxy parameters are available each also gets its own ``http_client``.
        """
        kwargs = {
            "api_key": self.config.OPENAI_API_KEY,
            "api_version": self.config.OPENAI_API_VERSION,
            "azure_endpoint": self.config.OPENAI_BASE_URL,
        }
        async_kwargs = dict(kwargs)

        # to use proxy, openai v1 needs http_client
        proxy_params = self._get_proxy_params()
        if proxy_params:
            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)

        return kwargs, async_kwargs

    def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
        """Build the keyword arguments for a chat completion request.

        Entries in ``configs`` override the defaults. The ``timeout`` that is
        sent is the larger of ``CONFIG.timeout`` and the ``timeout`` argument,
        and is applied last.
        """
        kwargs = {
            "messages": messages,
            "max_tokens": self.get_max_tokens(messages),
            "n": 1,
            "stop": None,
            "temperature": 0.3,
            "model": self.model,
            **configs,
        }
        kwargs["timeout"] = max(CONFIG.timeout, timeout)
        return kwargs
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/5 23:00
|
||||
@Author : alexanderwu
|
||||
@File : base_chatbot.py
|
||||
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -17,13 +18,13 @@ class BaseChatbot(ABC):
|
|||
use_system_prompt: bool = True
|
||||
|
||||
@abstractmethod
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
"""Ask GPT a question and get an answer"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_batch(self, msgs: list) -> str:
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a series of answers"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_code(self, msgs: list) -> str:
|
||||
def ask_code(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a piece of code"""
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
@Time : 2023/5/5 23:04
|
||||
@Author : alexanderwu
|
||||
@File : base_gpt_api.py
|
||||
@Desc : mashenquan, 2023/8/22. + try catch
|
||||
"""
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_chatbot import BaseChatbot
|
||||
|
||||
|
||||
|
|
@ -33,62 +33,65 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
rsp = self.completion(message, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream=True) -> str:
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
message = (
|
||||
self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
if self.use_system_prompt
|
||||
else [self._user_msg(msg)]
|
||||
)
|
||||
message = self._system_msgs(system_msgs)
|
||||
else:
|
||||
message = (
|
||||
[self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
)
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream)
|
||||
message = [self._default_system_msg()]
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
# logger.debug(rsp)
|
||||
return rsp
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
|
||||
|
||||
def ask_batch(self, msgs: list) -> str:
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp = self.completion(context)
|
||||
rsp = self.completion(context, timeout=timeout)
|
||||
rsp_text = self.get_choice_text(rsp)
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_batch(self, msgs: list) -> str:
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Sequential questioning"""
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp_text = await self.acompletion_text(context)
|
||||
rsp_text = await self.acompletion_text(context, timeout=timeout)
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
def ask_code(self, msgs: list[str]) -> str:
|
||||
def ask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = self.ask_batch(msgs)
|
||||
rsp_text = self.ask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
async def aask_code(self, msgs: list[str]) -> str:
|
||||
async def aask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = await self.aask_batch(msgs)
|
||||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
@abstractmethod
|
||||
def completion(self, messages: list[dict]):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
|
@ -98,7 +101,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
|
|
@ -109,7 +112,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""Asynchronous version of completion. Return str. Support stream-print"""
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
|
|
@ -145,7 +148,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
:return dict: return first function of choice, for example,
|
||||
{'name': 'execute', 'arguments': '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}'}
|
||||
"""
|
||||
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
|
||||
return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"]
|
||||
|
||||
def get_choice_function_arguments(self, rsp: dict) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
|
@ -158,8 +161,13 @@ class BaseGPTAPI(BaseChatbot):
|
|||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
return "\n".join([f"{i.role}: {i.content}" for i in messages])
|
||||
|
||||
def messages_to_dict(self, messages):
    """Convert message objects to a list of dicts by calling ``to_dict()`` on each,

    e.g. [{"role": "user", "content": msg}].
    """
    converted = []
    for message in messages:
        converted.append(message.to_dict())
    return converted
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
    """Close connection.

    Coroutine hook that concrete providers implement; this body does nothing.
    """
|
||||
|
|
|
|||
|
|
@ -2,24 +2,142 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : fireworks.ai's api
|
||||
|
||||
import openai
|
||||
import re
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from openai import APIConnectionError, AsyncStream
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
# Prices per model grade, in USD per 1M tokens (e.g. 0.2 means $0.2/1M tokens).
MODEL_GRADE_TOKEN_COSTS = {
    # abnormal condition
    "-1": dict(prompt=0.0, completion=0.0),
    # 16 means model size <= 16B
    "16": dict(prompt=0.2, completion=0.8),
    # 80 means 16B < model size <= 80B
    "80": dict(prompt=0.7, completion=2.8),
    "mixtral-8x7b": dict(prompt=0.4, completion=1.6),
}
|
||||
|
||||
|
||||
class FireworksCostManager(CostManager):
    """Cost manager that prices token usage by Fireworks model grade."""

    def model_grade_token_costs(self, model: str) -> dict[str, float]:
        """Pick the per-1M-token prices that apply to ``model``.

        Names containing "mixtral-8x7b" have their own entry. Otherwise the
        model size in billions is read from a ``-<number>b`` part of the name
        and mapped to the "16" or "80" grade; a missing or out-of-range size
        falls back to the zero-cost "-1" entry.
        """
        if "mixtral-8x7b" in model:
            return MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]

        # -1 marks a name that carries no recognizable size.
        sizes = re.findall(".*-([0-9.]+)b", model)
        model_size = float(sizes[0]) if sizes else -1

        if 0 < model_size <= 16:
            grade = "16"
        elif 16 < model_size <= 80:
            grade = "80"
        else:
            grade = "-1"
        return MODEL_GRADE_TOKEN_COSTS[grade]

    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
        """
        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
        Add one call's usage to the running token and cost totals, log the
        new totals and mirror the total cost into ``CONFIG.total_cost``.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        prices = self.model_grade_token_costs(model)
        prompt_cost = prompt_tokens * prices["prompt"]
        completion_cost = completion_tokens * prices["completion"]
        # Prices are quoted per 1M tokens.
        cost = (prompt_cost + completion_cost) / 1000000
        self.total_cost += cost

        logger.info(
            f"Total running cost: ${self.total_cost:.4f} | Max budget: ${CONFIG.max_budget:.3f} | "
            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )
        CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_fireworks(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.fireworks_api_model
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
self._cost_manager = FireworksCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_fireworks(self, config: "Config"):
|
||||
openai.api_key = config.fireworks_api_key
|
||||
openai.api_base = config.fireworks_api_base
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
|
||||
async_kwargs = kwargs.copy()
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
    """Record ``usage`` with this provider's own cost manager.

    Nothing happens when ``self.config.calc_usage`` is off or ``usage`` is
    falsy. A failure while updating is logged and not re-raised.
    """
    if not (self.config.calc_usage and usage):
        return
    try:
        # use FireworksCostManager not CONFIG.cost_manager
        self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
    except Exception as e:
        logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
    """Return the costs reported by this provider's own cost manager."""
    manager = self._cost_manager
    return manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
    """Stream a chat completion, echo each piece of text and return the full text.

    Usage starts out zeroed and is replaced from a chunk that carries a finish
    reason; it is passed to ``_update_costs`` once the stream ends.
    """
    response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
        **self._cons_kwargs(messages), stream=True
    )

    pieces = []
    usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    # iterate through the stream of events
    async for chunk in response:
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        delta = choice.delta
        finish_reason = getattr(choice, "finish_reason", None)
        text = delta.content
        if text:
            pieces.append(text)
            print(text, end="")
        if finish_reason:
            # fireworks api return usage when finish_reason is not None
            usage = CompletionUsage(**chunk.usage)

    self._update_costs(usage)
    return "".join(pieces)
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(
    self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
    """Return the completion text for ``messages``.

    With ``stream`` set, the streaming path is used, which prints each token
    in place; otherwise the text is taken from a single completion response.
    ``generator`` and ``timeout`` are accepted but not used in this body.
    """
    if not stream:
        rsp = await self._achat_completion(messages)
        return self.get_choice_text(rsp)
    return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from metagpt.config import CONFIG, LLMProviderEnum
|
|||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
|
|
@ -53,7 +53,6 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
self.__init_gemini(CONFIG)
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
self._cost_manager = CostManager()
|
||||
|
||||
def __init_gemini(self, config: CONFIG):
|
||||
genai.configure(api_key=config.gemini_api_key)
|
||||
|
|
@ -76,10 +75,13 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
async def close(self):
    """Close connection; there is nothing to release here, so this is a no-op.

    Declared ``async`` to match the abstract ``BaseGPTAPI.close`` coroutine,
    so that ``await llm.close()`` works for this provider as well.
    """
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
@ -134,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -14,24 +14,35 @@ class HumanProvider(BaseGPTAPI):
|
|||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
rsp = input(msg)
|
||||
if rsp in ["exit", "quit"]:
|
||||
exit()
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
return self.ask(msg)
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=3,
|
||||
) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
return ""
|
||||
|
||||
async def close(self):
    """Close connection.

    No-op coroutine: this body releases nothing.
    """
|
||||
|
|
|
|||
16
metagpt/provider/metagpt_api.py
Normal file
16
metagpt/provider/metagpt_api.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : metagpt_api.py
|
||||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.provider import OpenAIGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
class MetaGPTAPI(OpenAIGPTAPI):
    """Provider registered under ``LLMProviderEnum.METAGPT``; behaves exactly like its OpenAI base class."""

    def __init__(self):
        """Initialize through the base class without any extra setup."""
        super().__init__()
|
||||
|
|
@ -19,7 +19,8 @@ from metagpt.logs import log_llm_stream, logger
|
|||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class OllamaCostManager(CostManager):
|
||||
|
|
@ -56,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
|
||||
self.model = config.ollama_api_model
|
||||
|
||||
async def close(self):
    """Close connection; there is nothing to release here, so this is a no-op.

    Declared ``async`` to match the abstract ``BaseGPTAPI.close`` coroutine,
    so that ``await llm.close()`` works for this provider as well.
    """
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
|
@ -143,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : self-host open llm model with openai-compatible interface
|
||||
|
||||
import openai
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
class OpenLLMCostManager(CostManager):
|
||||
|
|
@ -26,7 +28,7 @@ class OpenLLMCostManager(CostManager):
|
|||
self.total_completion_tokens += completion_tokens
|
||||
|
||||
logger.info(
|
||||
f"Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
|
@ -35,14 +37,43 @@ class OpenLLMCostManager(CostManager):
|
|||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_openllm(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.open_llm_api_model
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openllm(self, config: "Config"):
|
||||
openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value
|
||||
openai.api_base = config.open_llm_api_base
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
|
||||
async_kwargs = kwargs.copy()
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
    """Estimate the token usage of one exchange.

    Args:
        messages: The request messages whose tokens are counted as the prompt.
        rsp: The response text whose tokens are counted as the completion.

    Returns:
        A usage object. It is all zeros when ``CONFIG.calc_usage`` is off. If
        counting fails the error is logged and the values counted so far are
        kept.
    """
    usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    if not CONFIG.calc_usage:
        return usage

    try:
        usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
        usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
    except Exception as e:
        logger.error(f"usage calculation failed!: {e}")

    # Keep the total in step with its parts instead of leaving it at 0.
    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens
    return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
    """Record ``usage`` with this provider's own cost manager.

    Nothing happens when ``self.config.calc_usage`` is off or ``usage`` is
    falsy. A failure while updating is logged and not re-raised.
    """
    if not (self.config.calc_usage and usage):
        return
    try:
        # use OpenLLMCostManager not CONFIG.cost_manager
        self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
    except Exception as e:
        logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
    """Return the costs reported by this provider's own cost manager."""
    manager = self._cost_manager
    return manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -3,20 +3,19 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import NamedTuple, Union
|
||||
from typing import AsyncIterator, List, Union
|
||||
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
AsyncAzureOpenAI,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
)
|
||||
import openai
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
|
@ -29,15 +28,15 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
|
|
@ -69,75 +68,6 @@ class RateLimiter:
|
|||
self.last_call_time = time.time()
|
||||
|
||||
|
||||
class Costs(NamedTuple):
|
||||
total_prompt_tokens: int
|
||||
total_completion_tokens: int
|
||||
total_cost: float
|
||||
total_budget: float
|
||||
|
||||
|
||||
class CostManager(metaclass=Singleton):
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
def __init__(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
self.total_cost = 0
|
||||
self.total_budget = 0
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
"""
|
||||
Get the total number of prompt tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of prompt tokens.
|
||||
"""
|
||||
return self.total_prompt_tokens
|
||||
|
||||
def get_total_completion_tokens(self):
|
||||
"""
|
||||
Get the total number of completion tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of completion tokens.
|
||||
"""
|
||||
return self.total_completion_tokens
|
||||
|
||||
def get_total_cost(self):
|
||||
"""
|
||||
Get the total cost of API calls.
|
||||
|
||||
Returns:
|
||||
float: The total cost of API calls.
|
||||
"""
|
||||
return self.total_cost
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
"""Get all costs"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
@ -157,37 +87,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openai()
|
||||
self._init_openai()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self):
|
||||
self.is_azure = self.config.openai_api_type == "azure"
|
||||
self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
def _init_openai(self):
|
||||
self.rpm = int(self.config.RPM or 10)
|
||||
self._make_client()
|
||||
|
||||
def _make_client(self):
|
||||
kwargs, async_kwargs = self._make_client_kwargs()
|
||||
|
||||
if self.is_azure:
|
||||
self.client = AzureOpenAI(**kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(**async_kwargs)
|
||||
else:
|
||||
self.client = OpenAI(**kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_kwargs)
|
||||
# https://github.com/openai/openai-python#async-usage
|
||||
self.client = OpenAI(**kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_kwargs)
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
if self.is_azure:
|
||||
kwargs = dict(
|
||||
api_key=self.config.openai_api_key,
|
||||
api_version=self.config.openai_api_version,
|
||||
azure_endpoint=self.config.openai_base_url,
|
||||
)
|
||||
else:
|
||||
kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
|
||||
|
||||
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
|
||||
async_kwargs = kwargs.copy()
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
|
|
@ -202,64 +118,51 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
params = {}
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
if self.config.OPENAI_BASE_URL:
|
||||
params["base_url"] = self.config.OPENAI_BASE_URL
|
||||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
**self._cons_kwargs(messages, timeout=timeout), stream=True
|
||||
)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
if chunk.choices:
|
||||
chunk_message = chunk.choices[0].delta # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if chunk_message.content:
|
||||
log_llm_stream(chunk_message.content)
|
||||
print()
|
||||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
yield chunk_message
|
||||
|
||||
full_reply_content = "".join([m.content for m in collected_messages if m.content])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"timeout": 3,
|
||||
"model": self.model,
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
kwargs["timeout"] = max(CONFIG.timeout, timeout)
|
||||
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages))
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
kwargs = self._cons_kwargs(messages, timeout=timeout)
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages))
|
||||
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
return self._chat_completion(messages)
|
||||
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
return self._chat_completion(messages, timeout=timeout)
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> ChatCompletion:
|
||||
return await self._achat_completion(messages)
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
|
|
@ -268,14 +171,25 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
rsp = await self._achat_completion(messages)
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
log_llm_stream(i)
|
||||
collected_messages.append(i)
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
def _func_configs(self, messages: list[dict], **kwargs) -> dict:
|
||||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
"""
|
||||
|
|
@ -286,17 +200,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages, **kwargs)
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
||||
def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion:
|
||||
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion:
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
**self._func_configs(messages, **chat_configs)
|
||||
)
|
||||
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
|
||||
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
|
|
@ -349,8 +262,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
try:
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
except openai.BadRequestError as e:
|
||||
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
|
||||
raise e
|
||||
|
||||
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
|
@ -380,7 +297,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
return usage
|
||||
|
||||
async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]:
|
||||
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
|
||||
"""Return full JSON"""
|
||||
split_batches = self.split_batches(batch)
|
||||
all_results = []
|
||||
|
|
@ -389,16 +306,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
logger.info(small_batch)
|
||||
await self.wait_if_needed(len(small_batch))
|
||||
|
||||
future = [self.acompletion(prompt) for prompt in small_batch]
|
||||
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
|
||||
results = await asyncio.gather(*future)
|
||||
logger.info(results)
|
||||
all_results.extend(results)
|
||||
|
||||
return all_results
|
||||
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
|
||||
"""Only return plain text"""
|
||||
raw_results = await self.acompletion_batch(batch)
|
||||
raw_results = await self.acompletion_batch(batch, timeout=timeout)
|
||||
results = []
|
||||
for idx, raw_result in enumerate(raw_results, start=1):
|
||||
result = self.get_choice_text(raw_result)
|
||||
|
|
@ -409,18 +326,101 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage and usage:
|
||||
try:
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed!", e)
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
def moderation(self, content: Union[str, list[str]]):
|
||||
return self.client.moderations.create(input=content)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
return await self.async_client.moderations.create(input=content)
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
if self.async_client:
|
||||
await self.async_client.close()
|
||||
self.async_client = None
|
||||
|
||||
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
if limit > 0 and text_length < limit:
|
||||
return text
|
||||
summary = ""
|
||||
while max_count > 0:
|
||||
if text_length < max_token_count:
|
||||
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
break
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
summary = summaries[0]
|
||||
break
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
return summary
|
||||
|
||||
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await self.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def split_texts(text: str, window_size) -> List[str]:
|
||||
"""Splitting long text into sliding windows text"""
|
||||
if window_size <= 0:
|
||||
window_size = DEFAULT_TOKEN_SIZE
|
||||
total_len = len(text)
|
||||
if total_len <= window_size:
|
||||
return [text]
|
||||
|
||||
padding_size = 20 if window_size > 20 else 0
|
||||
windows = []
|
||||
idx = 0
|
||||
data_len = window_size - padding_size
|
||||
while idx < total_len:
|
||||
if window_size + idx > total_len: # 不足一个滑窗
|
||||
windows.append(text[idx:])
|
||||
break
|
||||
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
|
||||
# window_size=3, padding_size=1:
|
||||
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
|
||||
# idx=2, | idx=5 | idx=8 | ...
|
||||
w = text[idx : idx + window_size]
|
||||
windows.append(w)
|
||||
idx += data_len
|
||||
|
||||
return windows
|
||||
|
|
|
|||
|
|
@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
class SparkAPI(BaseGPTAPI):
|
||||
class SparkGPTAPI(BaseGPTAPI):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
|
|
@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI):
|
|||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
# 不支持
|
||||
logger.error("该功能禁用。")
|
||||
w = GetMessageFromWeb(messages)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
import zhipuai
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
|
|
@ -20,7 +19,7 @@ from metagpt.config import CONFIG, LLMProviderEnum
|
|||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
|
||||
|
|
@ -44,12 +43,12 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
self.__init_zhipuai(CONFIG)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self._cost_manager = CostManager()
|
||||
|
||||
def __init_zhipuai(self, config: CONFIG):
|
||||
assert config.zhipuai_api_key
|
||||
zhipuai.api_key = config.zhipuai_api_key
|
||||
openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
# due to use openai sdk, set the api_key but it will't be used.
|
||||
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
|
||||
|
||||
def _const_kwargs(self, messages: list[dict]) -> dict:
|
||||
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
|
||||
|
|
@ -61,32 +60,35 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
assert assist_msg["role"] == "assistant"
|
||||
return assist_msg.get("content")
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.invoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
|
||||
usage = resp.get("data").get("usage")
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
@ -129,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -5,19 +5,49 @@
|
|||
@Author : alexanderwu
|
||||
@File : repo_parser.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class RepoFileInfo(BaseModel):
|
||||
file: str
|
||||
classes: List = Field(default_factory=list)
|
||||
functions: List = Field(default_factory=list)
|
||||
globals: List = Field(default_factory=list)
|
||||
page_info: List = Field(default_factory=list)
|
||||
|
||||
|
||||
class CodeBlockInfo(BaseModel):
|
||||
lineno: int
|
||||
end_lineno: int
|
||||
type_name: str
|
||||
tokens: List = Field(default_factory=list)
|
||||
properties: Dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ClassInfo(BaseModel):
|
||||
name: str
|
||||
package: Optional[str] = None
|
||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
||||
methods: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RepoParser(BaseModel):
|
||||
base_directory: Path = Field(default=None)
|
||||
|
||||
|
|
@ -27,31 +57,32 @@ class RepoParser(BaseModel):
|
|||
"""Parse a Python file in the repository."""
|
||||
return ast.parse(file_path.read_text()).body
|
||||
|
||||
def extract_class_and_function_info(self, tree, file_path):
|
||||
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
|
||||
"""Extract class, function, and global variable information from the AST."""
|
||||
file_info = {
|
||||
"file": str(file_path.relative_to(self.base_directory)),
|
||||
"classes": [],
|
||||
"functions": [],
|
||||
"globals": [],
|
||||
}
|
||||
|
||||
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
|
||||
for node in tree:
|
||||
info = RepoParser.node_to_str(node)
|
||||
file_info.page_info.append(info)
|
||||
if isinstance(node, ast.ClassDef):
|
||||
class_methods = [m.name for m in node.body if is_func(m)]
|
||||
file_info["classes"].append({"name": node.name, "methods": class_methods})
|
||||
file_info.classes.append({"name": node.name, "methods": class_methods})
|
||||
elif is_func(node):
|
||||
file_info["functions"].append(node.name)
|
||||
file_info.functions.append(node.name)
|
||||
elif isinstance(node, (ast.Assign, ast.AnnAssign)):
|
||||
for target in node.targets if isinstance(node, ast.Assign) else [node.target]:
|
||||
if isinstance(target, ast.Name):
|
||||
file_info["globals"].append(target.id)
|
||||
file_info.globals.append(target.id)
|
||||
return file_info
|
||||
|
||||
def generate_symbols(self):
|
||||
def generate_symbols(self) -> List[RepoFileInfo]:
|
||||
files_classes = []
|
||||
directory = self.base_directory
|
||||
for path in directory.rglob("*.py"):
|
||||
|
||||
matching_files = []
|
||||
extensions = ["*.py", "*.js"]
|
||||
for ext in extensions:
|
||||
matching_files += directory.rglob(ext)
|
||||
for path in matching_files:
|
||||
tree = self._parse_file(path)
|
||||
file_info = self.extract_class_and_function_info(tree, path)
|
||||
files_classes.append(file_info)
|
||||
|
|
@ -79,6 +110,215 @@ class RepoParser(BaseModel):
|
|||
elif mode == "csv":
|
||||
self.generate_dataframe_structure(output_path)
|
||||
|
||||
@staticmethod
|
||||
def node_to_str(node) -> (int, int, str, str | Tuple):
|
||||
if any_to_str(node) == any_to_str(ast.Expr):
|
||||
return CodeBlockInfo(
|
||||
lineno=node.lineno,
|
||||
end_lineno=node.end_lineno,
|
||||
type_name=any_to_str(node),
|
||||
tokens=RepoParser._parse_expr(node),
|
||||
)
|
||||
mappings = {
|
||||
any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names],
|
||||
any_to_str(ast.Assign): RepoParser._parse_assign,
|
||||
any_to_str(ast.ClassDef): lambda x: x.name,
|
||||
any_to_str(ast.FunctionDef): lambda x: x.name,
|
||||
any_to_str(ast.ImportFrom): lambda x: {
|
||||
"module": x.module,
|
||||
"names": [RepoParser._parse_name(n) for n in x.names],
|
||||
},
|
||||
any_to_str(ast.If): RepoParser._parse_if,
|
||||
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
|
||||
}
|
||||
func = mappings.get(any_to_str(node))
|
||||
if func:
|
||||
code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node))
|
||||
val = func(node)
|
||||
if isinstance(val, dict):
|
||||
code_block.properties = val
|
||||
elif isinstance(val, list):
|
||||
code_block.tokens = val
|
||||
elif isinstance(val, str):
|
||||
code_block.tokens = [val]
|
||||
else:
|
||||
raise NotImplementedError(f"Not implement:{val}")
|
||||
return code_block
|
||||
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
|
||||
|
||||
@staticmethod
|
||||
def _parse_expr(node) -> List:
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
|
||||
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
|
||||
}
|
||||
func = funcs.get(any_to_str(node.value))
|
||||
if func:
|
||||
return func(node)
|
||||
raise NotImplementedError(f"Not implement: {node.value}")
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(n):
|
||||
if n.asname:
|
||||
return f"{n.name} as {n.asname}"
|
||||
return n.name
|
||||
|
||||
@staticmethod
|
||||
def _parse_if(n):
|
||||
tokens = [RepoParser._parse_variable(n.test.left)]
|
||||
for item in n.test.comparators:
|
||||
tokens.append(RepoParser._parse_variable(item))
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def _parse_variable(node):
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: x.value,
|
||||
any_to_str(ast.Name): lambda x: x.id,
|
||||
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
|
||||
}
|
||||
func = funcs.get(any_to_str(node))
|
||||
if not func:
|
||||
raise NotImplementedError(f"Not implement:{node}")
|
||||
return func(node)
|
||||
|
||||
@staticmethod
|
||||
def _parse_assign(node):
|
||||
return [RepoParser._parse_variable(t) for t in node.targets]
|
||||
|
||||
async def rebuild_class_views(self, path: str | Path = None):
|
||||
if not path:
|
||||
path = self.base_directory
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
return
|
||||
command = f"pyreverse {str(path)} -o dot"
|
||||
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"{result}")
|
||||
class_view_pathname = path / "classes.dot"
|
||||
class_views = await self._parse_classes(class_view_pathname)
|
||||
packages_pathname = path / "packages.dot"
|
||||
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
|
||||
class_view_pathname.unlink(missing_ok=True)
|
||||
packages_pathname.unlink(missing_ok=True)
|
||||
return class_views
|
||||
|
||||
async def _parse_classes(self, class_view_pathname):
|
||||
class_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return class_views
|
||||
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
|
||||
lines = await reader.readlines()
|
||||
for line in lines:
|
||||
package_name, info = RepoParser._split_class_line(line)
|
||||
if not package_name:
|
||||
continue
|
||||
class_name, members, functions = re.split(r"(?<!\\)\|", info)
|
||||
class_info = ClassInfo(name=class_name)
|
||||
class_info.package = package_name
|
||||
for m in members.split("\n"):
|
||||
if not m:
|
||||
continue
|
||||
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
|
||||
class_info.attributes[member_name] = m
|
||||
for f in functions.split("\n"):
|
||||
if not f:
|
||||
continue
|
||||
function_name, _ = f.split("(", 1)
|
||||
class_info.methods[function_name] = f
|
||||
class_views.append(class_info)
|
||||
return class_views
|
||||
|
||||
@staticmethod
|
||||
def _split_class_line(line):
|
||||
part_splitor = '" ['
|
||||
if part_splitor not in line:
|
||||
return None, None
|
||||
ix = line.find(part_splitor)
|
||||
class_name = line[0:ix].replace('"', "")
|
||||
left = line[ix:]
|
||||
begin_flag = "label=<{"
|
||||
end_flag = "}>"
|
||||
if begin_flag not in left or end_flag not in left:
|
||||
return None, None
|
||||
bix = left.find(begin_flag)
|
||||
eix = left.rfind(end_flag)
|
||||
info = left[bix + len(begin_flag) : eix]
|
||||
info = re.sub(r"<br[^>]*>", "\n", info)
|
||||
return class_name, info
|
||||
|
||||
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
    """Map dotted names to the paths found under ``path``, recursively.

    The result always contains an entry for ``path`` itself (slashes replaced by dots).
    Each sub-directory contributes its own recursive mapping, and each file is keyed by its
    path without suffix, slashes replaced by dots. File entries are written last, so they
    win over a directory entry that produces the same key. Errors raised while walking the
    directory are logged and the entries gathered so far are still returned.
    """
    root = str(path)
    mappings = {root.replace("/", "."): root}
    file_names = []
    try:
        directory = Path(path)
        if not directory.exists():
            return mappings
        for entry in directory.iterdir():
            if not entry.is_file():
                mappings.update(RepoParser._create_path_mapping(path=entry))
                continue
            file_names.append(str(entry))
    except Exception as e:
        logger.error(f"Error: {e}")

    mappings.update({str(Path(name).with_suffix("")).replace("/", "."): name for name in file_names})
    return mappings
|
||||
|
||||
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
    """Rewrite the ``package`` of every class view using the files found under ``path``.

    The root prefix is derived from ``path`` and the package of the first class view via
    ``RepoParser._find_root``. That prefix is cut from both the keys and the values of the
    path mapping, and each package is then passed through ``RepoParser._repair_ns``. The
    views are modified in place and the same list is returned; an empty input yields ``[]``.
    """
    if not class_views:
        return []

    first_view = class_views[0]
    full_key = str(path).lstrip("/").replace("/", ".")
    root_namespace = RepoParser._find_root(full_key, first_view.package)
    root_path = root_namespace.replace(".", "/")
    namespace_skip, path_skip = len(root_namespace), len(root_path)

    # Drop the root prefix from dotted keys and file paths alike.
    trimmed_mappings = {
        key[namespace_skip:]: value[path_skip:]
        for key, value in RepoParser._create_path_mapping(path=path).items()
    }

    for view in class_views:
        view.package = RepoParser._repair_ns(view.package, trimmed_mappings)
    return class_views
|
||||
|
||||
@staticmethod
|
||||
def _repair_ns(package, mappings):
|
||||
file_ns = package
|
||||
while file_ns != "":
|
||||
if file_ns not in mappings:
|
||||
ix = file_ns.rfind(".")
|
||||
file_ns = file_ns[0:ix]
|
||||
continue
|
||||
break
|
||||
internal_ns = package[ix + 1 :]
|
||||
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
|
||||
return ns
|
||||
|
||||
@staticmethod
|
||||
def _find_root(full_key, package) -> str:
|
||||
left = full_key
|
||||
while left != "":
|
||||
if left in package:
|
||||
break
|
||||
if "." not in left:
|
||||
break
|
||||
ix = left.find(".")
|
||||
left = left[ix + 1 :]
|
||||
ix = full_key.rfind(left)
|
||||
return "." + full_key[0:ix]
|
||||
|
||||
|
||||
def is_func(node):
    """Tell whether ``node`` is a synchronous or asynchronous function definition."""
    function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return isinstance(node, function_node_types)
|
||||
|
|
|
|||
143
metagpt/roles/assistant.py
Normal file
143
metagpt/roles/assistant.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/7
|
||||
@Author : mashenquan
|
||||
@File : assistant.py
|
||||
@Desc : I am attempting to incorporate certain symbol concepts from UML into MetaGPT, enabling it to have the
|
||||
ability to freely construct flows through symbol concatenation. Simultaneously, I am also striving to
|
||||
make these symbols configurable and standardized, making the process of building flows more convenient.
|
||||
For more about `fork` node in activity diagrams, see: `https://www.uml-diagrams.org/activity-diagrams.html`
|
||||
This file defines a `fork` style meta role capable of generating arbitrary roles at runtime based on a
|
||||
configuration file.
|
||||
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false
|
||||
indicates that further reasoning cannot continue.
|
||||
|
||||
"""
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.skill_loader import SkillsDeclaration
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class MessageType(Enum):
    """Tags that select how an LLM reply is handled."""

    # Plain conversation.
    Talk = "TALK"
    # Invocation of a declared skill.
    Skill = "SKILL"
|
||||
|
||||
|
||||
class Assistant(Role):
    """Assistant for solving common issues."""

    name: str = "Lily"
    profile: str = "An assistant"
    goal: str = "Help to solve problem"
    constraints: str = "Talk in {language}"  # `{language}` is substituted in __init__
    desc: str = ""
    memory: BrainMemory = Field(default_factory=BrainMemory)
    skills: Optional[SkillsDeclaration] = None  # loaded on first use in think()

    def __init__(self, **kwargs):
        """Create the role and substitute the reply language into ``constraints``.

        The language comes from the ``language`` keyword argument, falling back to
        ``CONFIG.language`` and finally to "Chinese".
        """
        super().__init__(**kwargs)
        language = kwargs.get("language") or CONFIG.language or "Chinese"
        self.constraints = self.constraints.format(language=language)

    async def think(self) -> bool:
        """Decide, from the latest user talk, whether to talk or to run a skill.

        Returns False when there is no pending talk; otherwise returns what ``_plan``
        returns for the LLM's classification of the talk.
        """
        last_talk = await self.refine_memory()
        if not last_talk:
            return False
        if not self.skills:
            skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
            self.skills = await SkillsDeclaration.load(skill_yaml_file_name=skill_path)

        # One instruction line per declared skill, then the TALK fallback and the question.
        prompt_parts = [
            f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
            for desc, name in self.skills.get_skill_list().items()
        ]
        prompt_parts.append(
            'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
        )
        prompt_parts.append(f"Now what specific action is explicitly mentioned in the text: {last_talk}\n")
        prompt = "".join(prompt_parts)

        rsp = await self._llm.aask(prompt, [])
        logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
        return await self._plan(rsp, last_talk=last_talk)

    async def act(self) -> Message:
        """Run the pending action and record its answer in memory.

        Returns None when the action produces a falsy result; otherwise the resulting
        ``Message`` (built from the result unless the result already is a ``Message``).
        """
        result = await self._rc.todo.run()
        if not result:
            return None

        if isinstance(result, Message):
            msg = result
        elif isinstance(result, str):
            msg = Message(content=result, role="assistant", cause_by=self._rc.todo)
        else:
            msg = Message(
                content=result.content,
                instruct_content=result.instruct_content,
                cause_by=type(self._rc.todo),
            )
        self.memory.add_answer(msg)
        return msg

    async def talk(self, text):
        """Store ``text`` in memory as a new user talk."""
        self.memory.add_talk(Message(content=text))

    async def _plan(self, rsp: str, **kwargs) -> bool:
        """Dispatch the LLM reply ``rsp`` to the skill handler or the talk handler.

        Anything that is not tagged as a skill is treated as talk.
        """
        skill, text = BrainMemory.extract_info(input_string=rsp)
        if skill == MessageType.Skill.value:
            handler = self.skill_handler
        else:
            handler = self.talk_handler
        return await handler(text, **kwargs)

    async def talk_handler(self, text, **kwargs) -> bool:
        """Schedule a ``TalkAction``; the ``last_talk`` keyword, when truthy, replaces ``text``."""
        history = self.memory.history_text
        context = kwargs.get("last_talk") or text
        self._rc.todo = TalkAction(
            context=context,
            knowledge=self.memory.get_knowledge(),
            history_summary=history,
            llm=self._llm,
            **kwargs,
        )
        return True

    async def skill_handler(self, text, **kwargs) -> bool:
        """Schedule a ``SkillAction`` for the skill named ``text``.

        Falls back to ``talk_handler`` when the skill is unknown or when argument parsing
        leaves ``args`` as None.
        """
        last_talk = kwargs.get("last_talk")
        skill = self.skills.get_skill(text)
        if not skill:
            logger.info(f"skill not found: {text}")
            return await self.talk_handler(text=last_talk, **kwargs)

        action = ArgumentsParingAction(skill=skill, llm=self._llm, ask=last_talk, **kwargs)
        await action.run(**kwargs)
        if action.args is not None:
            self._rc.todo = SkillAction(
                skill=skill, args=action.args, llm=self._llm, name=skill.name, desc=skill.description
            )
            return True
        return await self.talk_handler(text=last_talk, **kwargs)

    async def refine_memory(self) -> str:
        """Pop the last user talk and, if it relates to the history, merge it with a rewrite.

        Returns None when there is no pending talk.
        """
        last_talk = self.memory.pop_last_talk()
        if last_talk is None:  # No user feedback, unsure if past conversation is finished.
            return None
        if not self.memory.is_history_available:
            return last_talk

        history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self._llm)
        related = bool(last_talk) and await self.memory.is_related(
            text1=last_talk, text2=history_summary, llm=self._llm
        )
        if not related:
            return last_talk

        # Merge relevant content.
        merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self._llm)
        return f"{merged} {last_talk}"

    def get_memory(self) -> str:
        """Return the memory serialized by its ``json()`` method."""
        serialized = self.memory.json()
        return serialized

    def load_memory(self, jsn):
        """Replace the memory with one built from the mapping ``jsn``; failures are logged."""
        try:
            self.memory = BrainMemory(**jsn)
        except Exception as e:
            logger.exception(f"load error:{e}, data:{jsn}")
|
||||
|
|
@ -43,7 +43,7 @@ from metagpt.schema import (
|
|||
Documents,
|
||||
Message,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set
|
||||
|
||||
IS_PASS_PROMPT = """
|
||||
{context}
|
||||
|
|
@ -78,13 +78,17 @@ class Engineer(Role):
|
|||
n_borg: int = 1
|
||||
use_code_review: bool = False
|
||||
code_todos: list = []
|
||||
summarize_todos = []
|
||||
summarize_todos: list = []
|
||||
next_todo_action: str = ""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Register ``WriteCode``, subscribe to the watched actions and reset the todo queues."""
    super().__init__(**kwargs)

    # The single action this role performs, and the message causes it reacts to.
    self._init_actions([WriteCode])
    watched_actions = [WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug]
    self._watch(watched_actions)

    # Fresh per-instance work queues.
    self.code_todos, self.summarize_todos = [], []
    self.next_todo_action = any_to_name(WriteCode)
|
||||
|
||||
@staticmethod
|
||||
def _parse_tasks(task_msg: Document) -> list[str]:
|
||||
|
|
@ -128,8 +132,10 @@ class Engineer(Role):
|
|||
if self._rc.todo is None:
|
||||
return None
|
||||
if isinstance(self._rc.todo, WriteCode):
|
||||
self.next_todo_action = any_to_name(SummarizeCode)
|
||||
return await self._act_write_code()
|
||||
if isinstance(self._rc.todo, SummarizeCode):
|
||||
self.next_todo_action = any_to_name(WriteCode)
|
||||
return await self._act_summarize()
|
||||
return None
|
||||
|
||||
|
|
@ -301,3 +307,8 @@ class Engineer(Role):
|
|||
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm))
|
||||
if self.summarize_todos:
|
||||
self._rc.todo = self.summarize_todos[0]
|
||||
|
||||
@property
def todo(self) -> str:
    """AgentStore uses this attribute to display to the user what actions the current role should take."""
    upcoming_action = self.next_todo_action
    return upcoming_action
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@
|
|||
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
|
||||
"""
|
||||
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import any_to_name
|
||||
|
||||
|
||||
class ProductManager(Role):
|
||||
|
|
@ -29,20 +29,28 @@ class ProductManager(Role):
|
|||
profile: str = "Product Manager"
|
||||
goal: str = "efficiently create a successful product that meets market demands and user expectations"
|
||||
constraints: str = "utilize the same language as the user requirements for seamless communication"
|
||||
todo_action: str = ""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Register the two actions, subscribe to their triggers and set the initial todo name."""
    super().__init__(**kwargs)

    own_actions = [PrepareDocuments, WritePRD]
    self._init_actions(own_actions)
    watched_actions = [UserRequirement, PrepareDocuments]
    self._watch(watched_actions)
    self.todo_action = any_to_name(PrepareDocuments)
|
||||
|
||||
async def _think(self) -> None:
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if CONFIG.git_repo:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
return self._rc.todo
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self._rc.todo)
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
    """Delegate to the parent ``_observe``, always passing ``ignore_memory=True``.

    The ``ignore_memory`` argument received here is not used.
    """
    observed_count = await super()._observe(ignore_memory=True)
    return observed_count
|
||||
|
||||
@property
def todo(self) -> str:
    """AgentStore uses this attribute to display to the user what actions the current role should take."""
    current_action_name = self.todo_action
    return current_action_name
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of
|
||||
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
|
||||
"""
|
||||
|
|
@ -39,6 +40,17 @@ class Researcher(Role):
|
|||
if self.language not in ("en-us", "zh-cn"):
|
||||
logger.warning(f"The language `{self.language}` has not been tested, it may not work.")
|
||||
|
||||
async def _think(self) -> bool:
    """Select the next action, running the role's states strictly in order.

    Returns:
        True when a state has been selected (the first one when nothing is pending,
        otherwise the one after the current state). False when the current state is the
        last one, in which case ``_rc.todo`` is cleared.
    """
    if self._rc.todo is None:
        # Nothing pending: start from the first state.
        self._set_state(0)
        return True

    next_state = self._rc.state + 1
    if next_state < len(self._states):
        self._set_state(next_state)
        # Return True explicitly; falling off the end would yield None, which callers
        # testing the bool result would read as "cannot continue".
        return True

    self._rc.todo = None
    return False
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
|
||||
todo = self._rc.todo
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/11 14:42
|
||||
@Author : alexanderwu
|
||||
@File : role.py
|
||||
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
|
||||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116:
|
||||
1. Merge the `recv` functionality into the `_observe` function. Future message reading operations will be
|
||||
consolidated within the `_observe` function.
|
||||
|
|
@ -38,6 +39,7 @@ from metagpt.memory import Memory
|
|||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
import_class,
|
||||
read_json_file,
|
||||
|
|
@ -118,7 +120,7 @@ class RoleContext(BaseModel):
|
|||
|
||||
@property
|
||||
def important_memory(self) -> list[Message]:
|
||||
"""Get the information corresponding to the watched actions"""
|
||||
"""Retrieve information corresponding to the attention action."""
|
||||
return self.memory.get_by_actions(self.watch)
|
||||
|
||||
@property
|
||||
|
|
@ -317,6 +319,9 @@ class Role(BaseModel):
|
|||
# check RoleContext after adding watch actions
|
||||
self._rc.check(self._role_id)
|
||||
|
||||
def is_watch(self, caused_by: str):
    """Tell whether ``caused_by`` is among the entries of ``self._rc.watch``."""
    watched = self._rc.watch
    return caused_by in watched
|
||||
|
||||
def subscribe(self, tags: Set[str]):
|
||||
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
|
||||
buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name
|
||||
|
|
@ -340,6 +345,11 @@ class Role(BaseModel):
|
|||
env.set_subscription(self, self.subscription)
|
||||
self.refresh_system_message() # add env message to system message
|
||||
|
||||
@property
def action_count(self):
    """Return the number of actions registered on this role."""
    registered_actions = self._actions
    return len(registered_actions)
|
||||
|
||||
def _get_prefix(self):
|
||||
"""Get the role prefix"""
|
||||
if self.desc:
|
||||
|
|
@ -356,16 +366,18 @@ class Role(BaseModel):
|
|||
prefix += env_desc
|
||||
return prefix
|
||||
|
||||
async def _think(self) -> None:
|
||||
"""Think about what to do and decide on the next action"""
|
||||
async def _think(self) -> bool:
|
||||
"""Consider what to do and decide on the next course of action. Return false if nothing can be done."""
|
||||
if len(self._actions) == 1:
|
||||
# If there is only one action, then only this one can be performed
|
||||
self._set_state(0)
|
||||
return
|
||||
|
||||
return True
|
||||
|
||||
if self.recovered and self._rc.state >= 0:
|
||||
self._set_state(self._rc.state) # action to run from recovered state
|
||||
self.recovered = False # avoid max_react_loop out of work
|
||||
return
|
||||
self.set_recovered(False) # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
prompt = self._get_prefix()
|
||||
prompt += STATE_TEMPLATE.format(
|
||||
|
|
@ -387,6 +399,7 @@ class Role(BaseModel):
|
|||
if next_state == -1:
|
||||
logger.info(f"End actions with {next_state=}")
|
||||
self._set_state(next_state)
|
||||
return True
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
|
||||
|
|
@ -420,17 +433,17 @@ class Role(BaseModel):
|
|||
async def _observe(self, ignore_memory=False) -> int:
|
||||
"""Prepare new messages for processing from the message buffer and other sources."""
|
||||
# Read unprocessed messages from the msg buffer.
|
||||
news = self._rc.msg_buffer.pop_all()
|
||||
news = []
|
||||
if self.recovered:
|
||||
news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
else:
|
||||
self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
|
||||
|
||||
if not news:
|
||||
news = self._rc.msg_buffer.pop_all()
|
||||
# Store the read messages in your own memory to prevent duplicate processing.
|
||||
old_messages = [] if ignore_memory else self._rc.memory.get()
|
||||
self._rc.memory.add_batch(news)
|
||||
# Filter out messages of interest.
|
||||
self._rc.news = self._find_news(news, old_messages)
|
||||
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
|
||||
self.latest_observed_msg = self._rc.news[-1] if self._rc.news else None # record the latest observed msg
|
||||
|
||||
# Design Rules:
|
||||
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
|
|
@ -440,6 +453,29 @@ class Role(BaseModel):
|
|||
logger.debug(f"{self._setting} observed: {news_text}")
|
||||
return len(self._rc.news)
|
||||
|
||||
# async def _observe(self, ignore_memory=False) -> int:
|
||||
# """Prepare new messages for processing from the message buffer and other sources."""
|
||||
# # Read unprocessed messages from the msg buffer.
|
||||
# news = self._rc.msg_buffer.pop_all()
|
||||
# if self.recovered:
|
||||
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
# else:
|
||||
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
|
||||
#
|
||||
# # Store the read messages in your own memory to prevent duplicate processing.
|
||||
# old_messages = [] if ignore_memory else self._rc.memory.get()
|
||||
# self._rc.memory.add_batch(news)
|
||||
# # Filter out messages of interest.
|
||||
# self._rc.news = self._find_news(news, old_messages)
|
||||
#
|
||||
# # Design Rules:
|
||||
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
|
||||
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
|
||||
# if news_text:
|
||||
# logger.debug(f"{self._setting} observed: {news_text}")
|
||||
# return len(self._rc.news)
|
||||
|
||||
def publish_message(self, msg):
|
||||
"""If the role belongs to env, then the role's messages will be broadcast to env"""
|
||||
if not msg:
|
||||
|
|
@ -498,23 +534,6 @@ class Role(BaseModel):
|
|||
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
|
||||
return rsp
|
||||
|
||||
# # Replaced by run()
|
||||
# def recv(self, message: Message) -> None:
|
||||
# """add message to history."""
|
||||
# # self._history += f"\n{message}"
|
||||
# # self._context = self._history
|
||||
# if message in self._rc.memory.get():
|
||||
# return
|
||||
# self._rc.memory.add(message)
|
||||
|
||||
# # Replaced by run()
|
||||
# async def handle(self, message: Message) -> Message:
|
||||
# """Receive information and reply with actions"""
|
||||
# # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}")
|
||||
# self.recv(message)
|
||||
#
|
||||
# return await self._react()
|
||||
|
||||
def get_memories(self, k=0) -> list[Message]:
|
||||
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
|
||||
return self._rc.memory.get(k=k)
|
||||
|
|
@ -551,3 +570,20 @@ class Role(BaseModel):
|
|||
def is_idle(self) -> bool:
|
||||
"""If true, all actions have been executed."""
|
||||
return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty()
|
||||
|
||||
async def think(self) -> Action:
    """The exported `think` function: run ``_think`` and return the resulting ``_rc.todo``."""
    await self._think()
    selected_action = self._rc.todo
    return selected_action
|
||||
|
||||
async def act(self) -> ActionOutput:
    """The exported `act` function: run ``_act`` and wrap its message in an ``ActionOutput``."""
    message = await self._act()
    output_fields = {"content": message.content, "instruct_content": message.instruct_content}
    return ActionOutput(**output_fields)
|
||||
|
||||
@property
def todo(self) -> str:
    """AgentStore uses this attribute to display to the user what actions the current role should take."""
    if not self._actions:
        return ""
    # Only the first registered action is reported.
    return any_to_name(self._actions[0])
|
||||
|
|
|
|||
|
|
@ -15,14 +15,15 @@ from metagpt.tools import SearchEngineType
|
|||
|
||||
|
||||
class Sales(Role):
|
||||
name: str = "Xiaomei"
|
||||
profile: str = "Retail sales guide"
|
||||
desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
|
||||
"will answer questions only based on the information in the knowledge base."
|
||||
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
|
||||
" I don't know, and I won't tell you that this is from the knowledge base,"
|
||||
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
|
||||
"professional guide"
|
||||
name: str = "John Smith"
|
||||
profile: str = "Retail Sales Guide"
|
||||
desc: str = (
|
||||
"As a Retail Sales Guide, my name is John Smith. I specialize in addressing customer inquiries with "
|
||||
"expertise and precision. My responses are based solely on the information available in our knowledge"
|
||||
" base. In instances where your query extends beyond this scope, I'll honestly indicate my inability "
|
||||
"to provide an answer, rather than speculate or assume. Please note, each of my replies will be "
|
||||
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
|
||||
)
|
||||
|
||||
store: Optional[BaseStore] = None
|
||||
|
||||
|
|
|
|||
118
metagpt/roles/teacher.py
Normal file
118
metagpt/roles/teacher.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/27
|
||||
@Author : mashenquan
|
||||
@File : teacher.py
|
||||
@Desc : Used by Agent Store
|
||||
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class Teacher(Role):
    """Support configurable teacher roles,
    with native and teaching languages being replaceable through configurations."""

    name: str = "Lily"
    profile: str = "{teaching_language} Teacher"
    goal: str = "writing a {language} teaching plan part by part"
    constraints: str = "writing in {language}"
    desc: str = ""

    def __init__(self, **kwargs):
        """Create the role and pass every descriptive field through ``WriteTeachingPlanPart.format_value``."""
        super().__init__(**kwargs)
        self.name = WriteTeachingPlanPart.format_value(self.name)
        self.profile = WriteTeachingPlanPart.format_value(self.profile)
        self.goal = WriteTeachingPlanPart.format_value(self.goal)
        self.constraints = WriteTeachingPlanPart.format_value(self.constraints)
        self.desc = WriteTeachingPlanPart.format_value(self.desc)

    async def _think(self) -> bool:
        """Everything will be done part by part.

        On first use, one ``WriteTeachingPlanPart`` action is created per topic in
        ``TeachingPlanBlock.TOPICS`` from the first news message. After that the states are
        visited in order.

        Returns:
            True when a state was selected; False when all states are done, in which case
            ``_rc.todo`` is cleared.

        Raises:
            ValueError: If there is no news or the first news message was not caused by
                ``UserRequirement``.
        """
        if not self._actions:
            if not self._rc.news or self._rc.news[0].cause_by != any_to_str(UserRequirement):
                raise ValueError("Lesson content invalid.")
            # Report the topics through the logger rather than printing to stdout.
            logger.debug(f"teaching plan topics: {TeachingPlanBlock.TOPICS}")
            actions = [
                WriteTeachingPlanPart(context=self._rc.news[0].content, topic=topic, llm=self._llm)
                for topic in TeachingPlanBlock.TOPICS
            ]
            self._init_actions(actions)

        if self._rc.todo is None:
            self._set_state(0)
            return True

        if self._rc.state + 1 < len(self._states):
            self._set_state(self._rc.state + 1)
            return True

        self._rc.todo = None
        return False

    async def _react(self) -> Message:
        """Run every part in turn, join the outputs with blank lines, save and return them."""
        ret = Message(content="")
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
            msg = await self._act()
            if ret.content != "":
                ret.content += "\n\n\n"
            ret.content += msg.content
        logger.info(ret.content)
        await self.save(ret.content)
        return ret

    async def save(self, content):
        """Save teaching plan

        The file is written to ``CONFIG.workspace_path / "teaching_plan"`` under a name
        derived from the course title. Write failures are logged, not raised.
        """
        filename = Teacher.new_file_name(self.course_title)
        pathname = CONFIG.workspace_path / "teaching_plan"
        # Create missing parent directories too, so a fresh workspace does not fail here.
        pathname.mkdir(parents=True, exist_ok=True)
        pathname = pathname / filename
        try:
            async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
                await writer.write(content)
        except Exception as e:
            logger.error(f"Save failed:{e}")
        else:
            # Announce the location only when the write actually succeeded.
            logger.info(f"Save to:{pathname}")

    @staticmethod
    def new_file_name(lesson_title, ext=".md"):
        """Create a related file name based on `lesson_title` and `ext`."""
        # Define the special characters that need to be replaced.
        illegal_chars = r'[#@$%!*&\\/:*?"<>|\n\t \']'
        # Replace the special characters with underscores.
        filename = re.sub(illegal_chars, "_", lesson_title) + ext
        # Collapse runs of underscores into a single one.
        return re.sub(r"_+", "_", filename)

    @property
    def course_title(self):
        """Return course title of teaching plan

        The title is the first line of the response of the action whose topic is
        ``TeachingPlanBlock.COURSE_TITLE``, with leading ``#``, spaces and newlines removed.
        ``"teaching_plan"`` is returned when that action is missing or has no response yet.
        """
        default_title = "teaching_plan"
        for act in self._actions:
            if act.topic != TeachingPlanBlock.COURSE_TITLE:
                continue
            if act.rsp is None:
                return default_title
            title = act.rsp.lstrip("# \n")
            if "\n" in title:
                ix = title.index("\n")
                title = title[0:ix]
            return title

        return default_title
|
||||
|
|
@ -23,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Type, TypedDict, TypeVar
|
||||
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ from metagpt.utils.serialize import (
|
|||
)
|
||||
|
||||
|
||||
class RawMessage(TypedDict):
|
||||
class SimpleMessage(BaseModel):
|
||||
content: str
|
||||
role: str
|
||||
|
||||
|
|
@ -162,8 +162,7 @@ class Message(BaseModel):
|
|||
# prefix = '-'.join([self.role, str(self.cause_by)])
|
||||
if self.instruct_content:
|
||||
return f"{self.role}: {self.instruct_content.dict()}"
|
||||
else:
|
||||
return f"{self.role}: {self.content}"
|
||||
return f"{self.role}: {self.content}"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
|
@ -180,8 +179,19 @@ class Message(BaseModel):
|
|||
@handle_exception(exception_type=JSONDecodeError, default_return=None)
|
||||
def load(val):
|
||||
"""Convert the json string to object."""
|
||||
i = json.loads(val)
|
||||
return Message(**i)
|
||||
|
||||
try:
|
||||
m = json.loads(val)
|
||||
id = m.get("id")
|
||||
if "id" in m:
|
||||
del m["id"]
|
||||
msg = Message(**m)
|
||||
if id:
|
||||
msg.id = id
|
||||
return msg
|
||||
except JSONDecodeError as err:
|
||||
logger.error(f"parse json failed: {val}, error:{err}")
|
||||
return None
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
|
|
|
|||
|
|
@ -90,9 +90,12 @@ class Team(BaseModel):
|
|||
CONFIG.max_budget = investment
|
||||
logger.info(f"Investment: ${investment}.")
|
||||
|
||||
def _check_balance(self):
|
||||
if CONFIG.total_cost > CONFIG.max_budget:
|
||||
raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}")
|
||||
@staticmethod
|
||||
def _check_balance():
|
||||
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
|
||||
raise NoMoneyException(
|
||||
CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
|
||||
)
|
||||
|
||||
def run_project(self, idea, send_to: str = ""):
|
||||
"""Run a project from publishing user requirement."""
|
||||
|
|
@ -100,7 +103,8 @@ class Team(BaseModel):
|
|||
|
||||
# Human requirement.
|
||||
self.env.publish_message(
|
||||
Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL)
|
||||
Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL),
|
||||
peekable=False,
|
||||
)
|
||||
|
||||
def start_project(self, idea, send_to: str = ""):
|
||||
|
|
@ -120,7 +124,7 @@ class Team(BaseModel):
|
|||
logger.info(self.json(ensure_ascii=False))
|
||||
|
||||
@serialize_decorator
|
||||
async def run(self, n_round=3, idea="", send_to=""):
|
||||
async def run(self, n_round=3, idea="", send_to="", auto_archive=True):
|
||||
"""Run company until target round or no money"""
|
||||
if idea:
|
||||
self.run_project(idea=idea, send_to=send_to)
|
||||
|
|
@ -132,6 +136,5 @@ class Team(BaseModel):
|
|||
self._check_balance()
|
||||
|
||||
await self.env.run()
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.archive()
|
||||
self.env.archive(auto_archive)
|
||||
return self.env.history
|
||||
|
|
|
|||
|
|
@ -22,3 +22,8 @@ class WebBrowserEngineType(Enum):
|
|||
PLAYWRIGHT = "playwright"
|
||||
SELENIUM = "selenium"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@classmethod
def _missing_(cls, key):
    """Default type conversion: map any unrecognized value to ``CUSTOM``.

    ``Enum`` calls the single-underscore ``_missing_`` hook when a lookup by value finds no
    member, so this is the name that makes the fallback take effect.
    """
    return cls.CUSTOM

@classmethod
def __missing__(cls, key):
    """Return ``CUSTOM``; kept for callers that invoke this name explicitly.

    ``Enum`` never calls ``__missing__`` itself; the lookup fallback is ``_missing_``.
    """
    return cls.CUSTOM
|
||||
|
|
|
|||
|
|
@ -4,39 +4,110 @@
|
|||
@Time : 2023/6/9 22:22
|
||||
@Author : Leo Xiao
|
||||
@File : azure_tts.py
|
||||
@Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import aiofiles
|
||||
from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class AzureTTS:
|
||||
"""https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles"""
|
||||
"""Azure Text-to-Speech"""
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, lang, voice, role, text, output_file):
|
||||
subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY")
|
||||
region = CONFIG.get("AZURE_TTS_REGION")
|
||||
speech_config = SpeechConfig(subscription=subscription_key, region=region)
|
||||
def __init__(self, subscription_key, region):
|
||||
"""
|
||||
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
|
||||
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
|
||||
"""
|
||||
self.subscription_key = subscription_key if subscription_key else CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
self.region = region if region else CONFIG.AZURE_TTS_REGION
|
||||
|
||||
# 参数参考:https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles
|
||||
async def synthesize_speech(self, lang, voice, text, output_file):
|
||||
speech_config = SpeechConfig(subscription=self.subscription_key, region=self.region)
|
||||
speech_config.speech_synthesis_voice_name = voice
|
||||
audio_config = AudioConfig(filename=output_file)
|
||||
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
|
||||
|
||||
# if voice=="zh-CN-YunxiNeural":
|
||||
ssml_string = f"""
|
||||
<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>
|
||||
<voice name='{voice}'>
|
||||
<mstts:express-as style='affectionate' role='{role}'>
|
||||
{text}
|
||||
</mstts:express-as>
|
||||
</voice>
|
||||
</speak>
|
||||
"""
|
||||
# More detail: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice
|
||||
ssml_string = (
|
||||
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
f"xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{voice}'>{text}</voice></speak>"
|
||||
)
|
||||
|
||||
synthesizer.speak_ssml_async(ssml_string).get()
|
||||
return synthesizer.speak_ssml_async(ssml_string).get()
|
||||
|
||||
@staticmethod
|
||||
def role_style_text(role, style, text):
|
||||
return f'<mstts:express-as role="{role}" style="{style}">{text}</mstts:express-as>'
|
||||
|
||||
@staticmethod
|
||||
def role_text(role, text):
|
||||
return f'<mstts:express-as role="{role}">{text}</mstts:express-as>'
|
||||
|
||||
@staticmethod
|
||||
def style_text(style, text):
|
||||
return f'<mstts:express-as style="{style}">{text}</mstts:express-as>'
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscription_key="", region=""):
    """Text to speech
    For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`

    :param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
    :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param text: The text used for voice conversion.
    :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
    :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
    :return: Returns the Base64-encoded .wav file data if successful, otherwise an empty string.

    """
    if not text:
        return ""

    # Fill every omitted argument with its default.
    if not lang:
        lang = "zh-CN"
    if not voice:
        voice = "zh-CN-XiaomoNeural"
    if not role:
        role = "Girl"
    if not style:
        style = "affectionate"
    if not subscription_key:
        subscription_key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
    if not region:
        region = CONFIG.AZURE_TTS_REGION

    xml_value = AzureTTS.role_style_text(role=role, style=style, text=text)
    tts = AzureTTS(subscription_key=subscription_key, region=region)
    # The audio is synthesized into a uniquely named temporary file next to this module.
    filename = Path(__file__).resolve().parent / (str(uuid4()).replace("-", "") + ".wav")
    try:
        await tts.synthesize_speech(lang=lang, voice=voice, text=xml_value, output_file=str(filename))
        async with aiofiles.open(filename, mode="rb") as reader:
            data = await reader.read()
        return base64.b64encode(data).decode("utf-8")
    except Exception as e:
        logger.error(f"text:{text}, error:{e}")
        return ""
    finally:
        # Remove the temporary file on success AND on failure; it may not
        # exist at all if synthesis failed before anything was written.
        filename.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
azure_tts = AzureTTS()
|
||||
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav")
|
||||
Config()
|
||||
loop = asyncio.new_event_loop()
|
||||
v = loop.create_task(oas3_azsure_tts("测试,test"))
|
||||
loop.run_until_complete(v)
|
||||
print(v)
|
||||
|
|
|
|||
27
metagpt/tools/hello.py
Normal file
27
metagpt/tools/hello.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/2 16:03
|
||||
@Author : mashenquan
|
||||
@File : hello.py
|
||||
@Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service:
|
||||
|
||||
curl -X 'POST' \
|
||||
'http://localhost:8080/openapi/greeting/dave' \
|
||||
-H 'accept: text/plain' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
"""
|
||||
|
||||
import connexion
|
||||
|
||||
|
||||
# openapi implement
|
||||
async def post_greeting(name: str) -> str:
    """Build the plain-text greeting returned by the demo endpoint.

    :param name: The name to greet.
    :return: ``Hello <name>`` terminated by a newline.
    """
    greeting = "Hello {}\n".format(name)
    return greeting
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build a connexion aiohttp application whose specifications live in `../../.well-known/`.
    app = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
    # Register the demo specification; `title` is handed to it through `arguments`.
    app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
    # Port 8080 matches the curl example in the module docstring.
    app.run(port=8080)
|
||||
162
metagpt/tools/iflytek_tts.py
Normal file
162
metagpt/tools/iflytek_tts.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : iflytek_tts.py
|
||||
@Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from time import mktime
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import aiofiles
|
||||
import websockets as websockets
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class IFlyTekTTSStatus(Enum):
    """Frame markers; `synthesize_speech` compares `AudioData.status` with `STATUS_LAST_FRAME` to stop reading."""

    STATUS_FIRST_FRAME = 0  # The first frame
    STATUS_CONTINUE_FRAME = 1  # The intermediate frame
    STATUS_LAST_FRAME = 2  # The last frame
|
||||
|
||||
|
||||
class AudioData(BaseModel):
    """The `data` member of a response frame."""

    audio: str  # Audio payload of this frame; written to the output file by `synthesize_speech`
    status: int  # Frame marker, compared against `IFlyTekTTSStatus` values
    ced: str  # NOTE(review): meaning not visible in this file — confirm against the iFLYTEK API docs
|
||||
|
||||
|
||||
class IFlyTekTTSResponse(BaseModel):
    """One JSON message received from the TTS websocket, parsed with `IFlyTekTTSResponse(**json.loads(...))`."""

    code: int  # presumably the service's result code — TODO confirm which value means success
    message: str
    data: Optional[AudioData] = None  # May be absent; `synthesize_speech` stops reading when it is empty
    sid: str  # presumably a session id — TODO confirm
|
||||
|
||||
|
||||
DEFAULT_IFLYTEK_VOICE = "xiaoyan"


class IFlyTekTTS(object):
    """Client for the iFLYTEK online text-to-speech websocket API."""

    def __init__(self, app_id: str, api_key: str, api_secret: str):
        """
        :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
        :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
        :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
        """
        self.app_id = app_id or CONFIG.IFLYTEK_APP_ID
        self.api_key = api_key or CONFIG.IFLYTEK_API_KEY
        # Use the same configuration key that `oas3_iflytek_tts` falls back to.
        self.api_secret = api_secret or CONFIG.IFLYTEK_API_SECRET

    async def synthesize_speech(self, text, output_file: str, voice=DEFAULT_IFLYTEK_VOICE):
        """Synthesize `text` and write the decoded audio bytes to `output_file`.

        :param text: The text used for voice conversion.
        :param output_file: Path of the audio file to create.
        :param voice: Voice name sent as the `vcn` business parameter.
        """
        url = self._create_url()
        data = {
            "common": {"app_id": self.app_id},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": voice, "tte": "utf8"},
            "data": {"status": 2, "text": str(base64.b64encode(text.encode("utf-8")), "UTF8")},
        }
        req = json.dumps(data)
        async with websockets.connect(url) as websocket:
            # send request
            await websocket.send(req)

            # receive frames
            async with aiofiles.open(str(output_file), "wb") as writer:
                while True:
                    v = await websocket.recv()
                    rsp = IFlyTekTTSResponse(**json.loads(v))
                    if rsp.data:
                        # Each frame is base64-encoded on its own, so decode per
                        # frame: concatenating the encoded strings is not valid
                        # base64 once a frame carries padding.
                        await writer.write(base64.b64decode(rsp.data.audio))
                        if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value:
                            continue
                    break

    def _create_url(self):
        """Create request url"""
        url = "wss://tts-api.xfyun.cn/v2/tts"
        # Generate a timestamp in RFC1123 format
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # NOTE(review): the signed host (`ws-api.xfyun.cn`) differs from the URL
        # host (`tts-api.xfyun.cn`); kept as-is — confirm against the API docs.
        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        # Perform HMAC-SHA256 encryption
        signature_sha = hmac.new(
            self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
        ).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")

        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
            self.api_key,
            "hmac-sha256",
            "host date request-line",
            signature_sha,
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        # Combine the authentication parameters of the request into a dictionary.
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        # Concatenate the authentication parameters to generate the URL.
        url = url + "?" + urlencode(v)
        return url


# Export
async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key: str = "", api_secret: str = ""):
    """Text to speech
    For more details, check out:`https://www.xfyun.cn/doc/tts/online_tts/API.html`

    :param voice: Default `xiaoyan`. For more details, checkout: `https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B`
    :param text: The text used for voice conversion.
    :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
    :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :return: Returns the Base64-encoded .mp3 file data if successful, otherwise an empty string.

    """
    # Nothing to synthesize; mirrors the guard in the other `oas3_*` helpers.
    if not text:
        return ""
    if not app_id:
        app_id = CONFIG.IFLYTEK_APP_ID
    if not api_key:
        api_key = CONFIG.IFLYTEK_API_KEY
    if not api_secret:
        api_secret = CONFIG.IFLYTEK_API_SECRET
    if not voice:
        voice = CONFIG.IFLYTEK_VOICE or DEFAULT_IFLYTEK_VOICE

    filename = Path(__file__).parent / (uuid.uuid4().hex + ".mp3")
    try:
        tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret)
        await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice)
        async with aiofiles.open(str(filename), mode="rb") as reader:
            data = await reader.read()
        # Encode the complete audio once so the result is valid base64.
        base64_string = base64.b64encode(data).decode("utf-8")
    except Exception as e:
        logger.error(f"text:{text}, error:{e}")
        base64_string = ""
    finally:
        # The file does not exist when the connection failed before it was
        # opened; a bare unlink() would then raise out of this `finally`.
        filename.unlink(missing_ok=True)

    return base64_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Credentials are deliberately not hard-coded here: with the arguments
    # omitted, `oas3_iflytek_tts` falls back to the values in CONFIG.
    asyncio.run(oas3_iflytek_tts(text="你好,hello"))
|
||||
44
metagpt/tools/metagpt_oas3_api_svc.py
Normal file
44
metagpt/tools/metagpt_oas3_api_svc.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : metagpt_oas3_api_svc.py
|
||||
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import connexion
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt'
|
||||
|
||||
|
||||
def oas_http_svc():
    """Create the connexion application, register both API specifications and serve on port 8080."""
    service = connexion.AioHttpApp(__name__, specification_dir="../../.well-known/")
    for spec_file in ("metagpt_oas3_api.yaml", "openapi.yaml"):
        service.add_api(spec_file)
    service.run(port=8080)
|
||||
|
||||
|
||||
async def async_main():
    """Start the OAS 3.0 OpenAPI HTTP service in the background.

    The service runs in the default executor's worker thread. The returned
    future is awaited so that this coroutine finishes when the service stops
    and any exception raised by the service propagates to the caller instead
    of being silently dropped.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, oas_http_svc)
|
||||
|
||||
|
||||
def main():
    """Print the local UI address, then call `oas_http_svc()` directly."""
    print("http://localhost:8080/oas3/ui/")
    oas_http_svc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The background variant is currently disabled:
    # asyncio.run(async_main())
    main()
|
||||
110
metagpt/tools/metagpt_text_to_image.py
Normal file
110
metagpt/tools/metagpt_text_to_image.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : metagpt_text_to_image.py
|
||||
@Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class MetaGPTText2Image:
    """Text-to-image client that posts a generation request to a model REST endpoint."""

    def __init__(self, model_url):
        """
        :param model_url: Model REST api url
        """
        # Fall back to the same configuration key `oas3_metagpt_text_to_image` uses.
        self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL

    async def text_2_image(self, text, size_type="512x512"):
        """Text to image

        :param text: The text used for image conversion.
        :param size_type: One of ['512x512', '512x768']
        :return: The image data is returned in Base64 encoding.
        """

        headers = {"Content-Type": "application/json"}
        # `size_type` is "<width>x<height>".
        dims = size_type.split("x")
        data = {
            "prompt": text,
            "negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
            "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
            "seed": -1,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 20,
            "cfg_scale": 11,
            "width": int(dims[0]),
            "height": int(dims[1]),  # 768,
            "restore_faces": False,
            "tiling": False,
            "do_not_save_samples": False,
            "do_not_save_grid": False,
            "enable_hr": False,
            "hr_scale": 2,
            "hr_upscaler": "Latent",
            "hr_second_pass_steps": 0,
            "hr_resize_x": 0,
            "hr_resize_y": 0,
            "hr_upscale_to_x": 0,
            "hr_upscale_to_y": 0,
            "truncate_x": 0,
            "truncate_y": 0,
            "applied_old_hires_behavior_to": None,
            "eta": None,
            "sampler_index": "DPM++ SDE Karras",
            "alwayson_scripts": {},
        }

        class ImageResult(BaseModel):
            images: List
            parameters: Dict

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.model_url, headers=headers, json=data) as response:
                    result = ImageResult(**await response.json())
            if len(result.images) == 0:
                return ""
            return result.images[0]
        # The request is made with aiohttp, so its own error hierarchy must be
        # caught here; `requests` exceptions are never raised by this code.
        except aiohttp.ClientError as e:
            logger.error(f"An error occurred:{e}")
        return ""
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url=""):
    """Text to image

    :param text: The text used for image conversion.
    :param model_url: Model REST api url; when empty, `CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL` is used.
    :param size_type: One of ['512x512', '512x768']
    :return: The image data is returned in Base64 encoding; an empty string when `text` is empty.
    """
    if text:
        endpoint = model_url or CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
        converter = MetaGPTText2Image(endpoint)
        return await converter.text_2_image(text, size_type=size_type)
    return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
    Config()
    # asyncio.run creates the event loop AND closes it afterwards; the
    # hand-made loop used previously was never closed.
    v = asyncio.run(oas3_metagpt_text_to_image("Panda emoji"))
    print(v)
    data = base64.b64decode(v)
    with open("tmp.png", mode="wb") as writer:
        writer.write(data)
    print(v)
|
||||
90
metagpt/tools/openai_text_to_embedding.py
Normal file
90
metagpt/tools/openai_text_to_embedding.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : openai_text_to_embedding.py
|
||||
@Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality.
|
||||
For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object`
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class Embedding(BaseModel):
    """Represents an embedding vector returned by embedding endpoint."""

    object: str  # The object type, which is always "embedding".
    embedding: List[
        float
    ]  # The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide.
    index: int  # The index of the embedding in the list of embeddings.
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """Token counts; used as the `usage` member of `ResultEmbedding`."""

    prompt_tokens: int
    total_tokens: int
|
||||
|
||||
|
||||
class ResultEmbedding(BaseModel):
    """Shape of the JSON document the embedding helpers in this module document as their result."""

    object: str
    data: List[Embedding]  # One `Embedding` entry per returned vector
    model: str
    usage: Usage
|
||||
|
||||
|
||||
class OpenAIText2Embedding:
    """Client for the OpenAI embeddings REST endpoint."""

    def __init__(self, openai_api_key):
        """
        :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
        """
        self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY

    async def text_2_embedding(self, text, model="text-embedding-ada-002"):
        """Text to embedding

        :param text: The text used for embedding.
        :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
        :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
        """

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
        data = {"input": text, "model": model}
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post("https://api.openai.com/v1/embeddings", headers=headers, json=data) as response:
                    return await response.json()
        # The request is made with aiohttp, so its own error hierarchy must be
        # caught here; `requests` exceptions are never raised by this code.
        except aiohttp.ClientError as e:
            logger.error(f"An error occurred:{e}")
        return {}
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""):
    """Text to embedding

    :param text: The text used for embedding.
    :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
    :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
    :return: The result of `OpenAIText2Embedding.text_2_embedding`; an empty string when `text` is empty.
    """
    if text:
        key = openai_api_key or CONFIG.OPENAI_API_KEY
        embedder = OpenAIText2Embedding(key)
        return await embedder.text_2_embedding(text, model=model)
    return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
    Config()
    # asyncio.run creates the event loop AND closes it afterwards; the
    # hand-made loop used previously was never closed.
    v = asyncio.run(oas3_openai_text_to_embedding("Panda emoji"))
    print(v)
|
||||
86
metagpt/tools/openai_text_to_image.py
Normal file
86
metagpt/tools/openai_text_to_image.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : openai_text_to_image.py
|
||||
@Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class OpenAIText2Image:
    """Text-to-image helper built on the async client exposed by `LLM`."""

    def __init__(self):
        """Create an `LLM` instance and keep a handle to its `async_client`."""
        self._llm = LLM()
        self._client = self._llm.async_client

    def __del__(self):
        # `_llm` is missing when `LLM()` raised inside `__init__`; look it up
        # defensively so finalization never raises AttributeError.
        llm = getattr(self, "_llm", None)
        if llm:
            llm.close()

    async def text_2_image(self, text, size_type="1024x1024"):
        """Text to image

        :param text: The text used for image conversion.
        :param size_type: One of ['256x256', '512x512', '1024x1024']
        :return: The image data is returned in Base64 encoding.
        """
        try:
            result = await self._client.images.generate(prompt=text, n=1, size=size_type)
        except Exception as e:
            logger.error(f"An error occurred:{e}")
            return ""
        if result and len(result.data) > 0:
            return await OpenAIText2Image.get_image_data(result.data[0].url)
        return ""

    @staticmethod
    async def get_image_data(url):
        """Fetch image data from a URL and encode it as Base64

        :param url: Image url
        :return: Base64-encoded image data, or an empty string when the download fails.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()  # Raises for 4xx or 5xx responses.
                    image_data = await response.read()
            base64_image = base64.b64encode(image_data).decode("utf-8")
            return base64_image

        # The download is made with aiohttp, so its own error hierarchy must
        # be caught here; `requests` exceptions are never raised by this code.
        except aiohttp.ClientError as e:
            logger.error(f"An error occurred:{e}")
            return ""
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
    """Text to image

    :param text: The text used for image conversion.
    :param size_type: One of ['256x256', '512x512', '1024x1024']
    :return: The image data is returned in Base64 encoding; an empty string when `text` is empty.
    """
    if text:
        generator = OpenAIText2Image()
        return await generator.text_2_image(text, size_type=size_type)
    return ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
    Config()
    # asyncio.run creates the event loop AND closes it afterwards; the
    # hand-made loop used previously was never closed.
    v = asyncio.run(oas3_openai_text_to_image("Panda emoji"))
    print(v)
|
||||
|
|
@ -6,7 +6,6 @@ import asyncio
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from typing import List
|
||||
|
||||
|
|
@ -14,8 +13,7 @@ from aiohttp import ClientSession
|
|||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
# from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
|
||||
payload = {
|
||||
|
|
@ -79,10 +77,10 @@ class SDEngine:
|
|||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.workspace_path / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -17,14 +20,16 @@ class WebBrowserEngine:
|
|||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
if engine is None:
|
||||
raise NotImplementedError
|
||||
|
||||
if engine == WebBrowserEngineType.PLAYWRIGHT:
|
||||
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
|
||||
module = "metagpt.tools.web_browser_engine_playwright"
|
||||
run_func = importlib.import_module(module).PlaywrightWrapper().run
|
||||
elif engine == WebBrowserEngineType.SELENIUM:
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
|
||||
module = "metagpt.tools.web_browser_engine_selenium"
|
||||
run_func = importlib.import_module(module).SeleniumWrapper().run
|
||||
elif engine == WebBrowserEngineType.CUSTOM:
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
|
||||
run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
@ -47,6 +52,6 @@ if __name__ == "__main__":
|
|||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
return await WebBrowserEngine(engine=WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -144,6 +148,6 @@ if __name__ == "__main__":
|
|||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
|
||||
return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
return await PlaywrightWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from typing import Dict, Literal
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
|
|
@ -29,6 +33,7 @@ class SeleniumWrapper:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
options: Dict,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] | None = None,
|
||||
launch_kwargs: dict | None = None,
|
||||
*,
|
||||
|
|
@ -120,6 +125,6 @@ if __name__ == "__main__":
|
|||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
return await SeleniumWrapper(browser_type=browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import sys
|
|||
import traceback
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union, get_args, get_origin
|
||||
from typing import Any, Callable, List, Tuple, Union, get_args, get_origin
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@ -48,7 +48,7 @@ def check_cmd_exists(command) -> int:
|
|||
return result
|
||||
|
||||
|
||||
def require_python_version(req_version: tuple[int]) -> bool:
|
||||
def require_python_version(req_version: Tuple) -> bool:
|
||||
if not (2 <= len(req_version) <= 3):
|
||||
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
|
||||
return True if sys.version_info > req_version else False
|
||||
|
|
@ -367,7 +367,7 @@ def get_class_name(cls) -> str:
|
|||
return f"{cls.__module__}.{cls.__name__}"
|
||||
|
||||
|
||||
def any_to_str(val: str | typing.Callable) -> str:
|
||||
def any_to_str(val: str | Callable) -> str:
|
||||
"""Return the class name or the class name of the object, or 'val' if it's a string type."""
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
|
|
@ -406,6 +406,21 @@ def is_subscribed(message: "Message", tags: set):
|
|||
return False
|
||||
|
||||
|
||||
def any_to_name(val):
    """
    Return the final component of the dotted path that `any_to_str` yields for `val`.

    :param val: The value to convert.

    :return: The text after the last "." (the whole string when it contains no dot).
    """
    dotted_path = any_to_str(val)
    return dotted_path.rpartition(".")[2]
|
||||
|
||||
|
||||
def concat_namespace(*args) -> str:
    """Join the string form of every argument with ":" (no arguments gives an empty string)."""
    return ":".join(map(str, args))
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
"""
|
||||
Generates a logging function to be used after a call is retried.
|
||||
|
|
@ -520,3 +535,20 @@ async def aread(file_path: str) -> str:
|
|||
async with aiofiles.open(str(file_path), mode="r") as reader:
|
||||
content = await reader.read()
|
||||
return content
|
||||
|
||||
|
||||
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
    """Read a range of lines from a text file.

    Args:
        filename: Path of the file to read.
        lineno: 1-based number of the first line to return.
        end_lineno: 1-based number of the last line to return (inclusive).

    Returns:
        The selected lines concatenated into one string (line endings kept), or ""
        if the file does not exist or the range selects nothing.
    """
    if not Path(filename).exists():
        return ""
    lines = []
    async with aiofiles.open(str(filename), mode="r") as reader:
        for ix in range(1, end_lineno + 1):
            line = await reader.readline()
            if not line:
                # readline() returns "" at end of file; stop instead of polling an exhausted reader.
                break
            if ix >= lineno:
                lines.append(line)
    return "".join(lines)
|
||||
|
|
|
|||
82
metagpt/utils/cost_manager.py
Normal file
82
metagpt/utils/cost_manager.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/28
|
||||
@Author : mashenquan
|
||||
@File : cost_manager.py
|
||||
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.token_counter import TOKEN_COSTS
|
||||
|
||||
|
||||
class Costs(NamedTuple):
    """Immutable snapshot of the counters kept by `CostManager` (see `CostManager.get_costs`)."""

    total_prompt_tokens: int  # prompt tokens accumulated so far
    total_completion_tokens: int  # completion tokens accumulated so far
    total_cost: float  # accumulated cost of all recorded calls
    total_budget: float  # copied from CostManager.total_budget
|
||||
|
||||
|
||||
class CostManager(BaseModel):
    """Accumulate token usage and the resulting cost of model API calls."""

    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_budget: float = 0
    max_budget: float = 10.0
    total_cost: float = 0

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Record one API call: add its tokens to the running totals and its price to the total cost.

        Args:
            prompt_tokens (int): Tokens consumed by the prompt of this call.
            completion_tokens (int): Tokens produced as the completion of this call.
            model (str): Model name, used as the key into `TOKEN_COSTS`.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens
        # TOKEN_COSTS prices are divided by 1000 below, i.e. they are applied per 1000 tokens.
        pricing = TOKEN_COSTS[model]
        prompt_cost = prompt_tokens * pricing["prompt"]
        completion_cost = completion_tokens * pricing["completion"]
        cost = (prompt_cost + completion_cost) / 1000
        self.total_cost += cost
        logger.info(
            f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
            f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )

    def get_total_prompt_tokens(self):
        """
        Return the number of prompt tokens recorded so far.

        Returns:
            int: The accumulated prompt token count.
        """
        return self.total_prompt_tokens

    def get_total_completion_tokens(self):
        """
        Return the number of completion tokens recorded so far.

        Returns:
            int: The accumulated completion token count.
        """
        return self.total_completion_tokens

    def get_total_cost(self):
        """
        Return the cost accumulated over all recorded API calls.

        Returns:
            float: The accumulated cost.
        """
        return self.total_cost

    def get_costs(self) -> Costs:
        """Return the current counters bundled as a `Costs` tuple."""
        return Costs(
            total_prompt_tokens=self.total_prompt_tokens,
            total_completion_tokens=self.total_completion_tokens,
            total_cost=self.total_cost,
            total_budget=self.total_budget,
        )
|
||||
84
metagpt/utils/di_graph_repository.py
Normal file
84
metagpt/utils/di_graph_repository.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : di_graph_repository.py
|
||||
@Desc : Graph repository based on DiGraph
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import aiofiles
|
||||
import networkx
|
||||
|
||||
from metagpt.utils.graph_repository import SPO, GraphRepository
|
||||
|
||||
|
||||
class DiGraphRepository(GraphRepository):
    """Graph repository that keeps triples in an in-memory ``networkx.DiGraph``.

    Every triple is stored as a directed edge ``subject -> object_`` whose
    ``predicate`` edge attribute carries the predicate. The graph can be saved to
    and loaded from a JSON file in node-link format.
    """

    def __init__(self, name: str, **kwargs):
        """Create an empty repository.

        Args:
            name: Repository name; also used as the file stem when saving.
            **kwargs: Extra options kept by the base class; ``root`` is read as the default save directory.
        """
        super().__init__(name=name, **kwargs)
        self._repo = networkx.DiGraph()

    async def insert(self, subject: str, predicate: str, object_: str):
        """Store the triple as an edge from ``subject`` to ``object_`` tagged with ``predicate``."""
        # NOTE(review): a DiGraph holds one edge per (subject, object_) pair, so inserting a second
        # predicate for the same pair presumably overwrites the first — confirm this is acceptable.
        self._repo.add_edge(subject, object_, predicate=predicate)

    async def upsert(self, subject: str, predicate: str, object_: str):
        """Not implemented yet; currently does nothing."""
        pass

    async def update(self, subject: str, predicate: str, object_: str):
        """Not implemented yet; currently does nothing."""
        pass

    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        """Return every stored triple matching the given filters.

        Args:
            subject: Keep only triples with this subject; a falsy value disables the filter.
            predicate: Keep only triples with this predicate; a falsy value disables the filter.
            object_: Keep only triples with this object; a falsy value disables the filter.

        Returns:
            The matching triples as `SPO` objects.
        """
        result = []
        for s, o, p in self._repo.edges(data="predicate"):
            if subject and subject != s:
                continue
            if predicate and predicate != p:
                continue
            if object_ and object_ != o:
                continue
            result.append(SPO(subject=s, predicate=p, object_=o))
        return result

    def json(self) -> str:
        """Serialize the graph to a JSON string in networkx node-link format."""
        m = networkx.node_link_data(self._repo)
        data = json.dumps(m)
        return data

    async def save(self, path: str | Path = None):
        """Write the graph to ``<path>/<name>.json``.

        Args:
            path: Target directory; defaults to the ``root`` option given at construction.

        Raises:
            ValueError: If neither ``path`` nor a ``root`` option is available.
        """
        data = self.json()
        path = path or self._kwargs.get("root")
        if not path:
            raise ValueError("No directory to save to: pass `path` or construct the repository with `root`.")
        # Accept plain strings as promised by the signature; Path-only methods are used below.
        directory = Path(path)
        directory.mkdir(parents=True, exist_ok=True)
        pathname = directory / self.name
        async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
            await writer.write(data)

    async def load(self, pathname: str | Path):
        """Replace the in-memory graph with the node-link JSON stored at ``pathname``."""
        async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
            data = await reader.read(-1)
        m = json.loads(data)
        self._repo = networkx.node_link_graph(m)

    @staticmethod
    async def load_from(pathname: str | Path) -> GraphRepository:
        """Build a repository for ``pathname``, loading its content if the file exists.

        The repository name is the file name without its suffix and ``root`` is the parent directory.
        """
        pathname = Path(pathname)
        name = pathname.with_suffix("").name
        root = pathname.parent
        graph = DiGraphRepository(name=name, root=root)
        if pathname.exists():
            await graph.load(pathname=pathname)
        return graph

    @property
    def root(self) -> str:
        """The ``root`` option passed at construction, or None if it was not given."""
        return self._kwargs.get("root")

    @property
    def pathname(self) -> Path:
        """Path of the JSON file for this repository: ``<root>/<name>.json``."""
        p = Path(self.root) / self.name
        return p.with_suffix(".json")
|
||||
150
metagpt/utils/graph_repository.py
Normal file
150
metagpt/utils/graph_repository.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : graph_repository.py
|
||||
@Desc : Superclass for graph repository.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.repo_parser import ClassInfo, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace
|
||||
|
||||
|
||||
class GraphKeyword:
    """String constants used as predicates and as well-known objects of graph triples."""

    # Predicate stating what kind of thing a subject is.
    IS = "is"
    # Objects used together with the IS predicate by the GraphRepository helpers.
    CLASS = "class"
    FUNCTION = "function"
    SOURCE_CODE = "source_code"
    # Fallback file type used when a file suffix is not recognised.
    NULL = "<null>"
    GLOBAL_VARIABLE = "global_variable"
    CLASS_FUNCTION = "class_function"
    CLASS_PROPERTY = "class_property"
    # Predicates that attach related information to a subject.
    HAS_CLASS = "has_class"
    HAS_PAGE_INFO = "has_page_info"
    HAS_CLASS_VIEW = "has_class_view"
    HAS_SEQUENCE_VIEW = "has_sequence_view"
    HAS_ARGS_DESC = "has_args_desc"
    HAS_TYPE_DESC = "has_type_desc"
|
||||
|
||||
|
||||
class SPO(BaseModel):
    """One subject-predicate-object triple, as returned by `GraphRepository.select`."""

    subject: str
    predicate: str
    # Trailing underscore avoids clashing with the name of the builtin `object`.
    object_: str
|
||||
|
||||
|
||||
class GraphRepository(ABC):
    """Abstract base class for repositories of subject-predicate-object triples.

    Subclasses implement the storage operations. The static helpers translate parsed
    repository metadata into triples and insert them through that interface.
    """

    # Source-file suffix -> language label stored as the object of an IS triple.
    _FILE_TYPES = {".py": "python", ".js": "javascript"}

    def __init__(self, name: str, **kwargs):
        """Remember the repository name and any implementation-specific options."""
        self._repo_name = name
        self._kwargs = kwargs

    @abstractmethod
    async def insert(self, subject: str, predicate: str, object_: str):
        """Insert a triple."""
        pass

    @abstractmethod
    async def upsert(self, subject: str, predicate: str, object_: str):
        """Insert a triple or update it if it already exists."""
        pass

    @abstractmethod
    async def update(self, subject: str, predicate: str, object_: str):
        """Update an existing triple."""
        pass

    @abstractmethod
    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        """Return the triples matching the given subject, predicate and object filters."""
        pass

    @property
    def name(self) -> str:
        """The repository name given at construction."""
        return self._repo_name

    @staticmethod
    async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
        """Insert the triples describing one parsed source file.

        Records the file itself, its file type, and its classes (with their methods),
        functions, globals and page-info code blocks. Names are namespaced with
        `concat_namespace`, starting from the file path.

        Args:
            graph_db: Repository receiving the triples.
            file_info: Parsed information about one source file.
        """
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
        file_type = GraphRepository._FILE_TYPES.get(Path(file_info.file).suffix, GraphKeyword.NULL)
        await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
        for c in file_info.classes:
            class_name = c.get("name", "")
            class_id = concat_namespace(file_info.file, class_name)
            await graph_db.insert(
                subject=file_info.file,
                predicate=GraphKeyword.HAS_CLASS,
                object_=class_id,
            )
            await graph_db.insert(
                subject=class_id,
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            methods = c.get("methods", [])
            for fn in methods:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, class_name, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
        for f in file_info.functions:
            await graph_db.insert(
                subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
            )
        for g in file_info.globals:
            await graph_db.insert(
                subject=concat_namespace(file_info.file, g),
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.GLOBAL_VARIABLE,
            )
        for code_block in file_info.page_info:
            if code_block.tokens:
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, *code_block.tokens),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.json(ensure_ascii=False),
                )
            for k, v in code_block.properties.items():
                await graph_db.insert(
                    subject=concat_namespace(file_info.file, k, v),
                    predicate=GraphKeyword.HAS_PAGE_INFO,
                    object_=code_block.json(ensure_ascii=False),
                )

    @staticmethod
    async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
        """Insert the triples describing a list of class views.

        Each view's ``package`` is expected in the form ``"<filename>:<class name>"``;
        it is used as the identifier of the class node.

        Args:
            graph_db: Repository receiving the triples.
            class_views: Class descriptions with ``package``, ``attributes`` and ``methods``.
        """
        for c in class_views:
            filename, _ = c.package.split(":", 1)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
            file_type = GraphRepository._FILE_TYPES.get(Path(filename).suffix, GraphKeyword.NULL)
            await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
            # Point HAS_CLASS at the namespaced class id (the node described just below), consistent
            # with update_graph_db_with_file_info, rather than at the bare class name.
            await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=c.package)
            await graph_db.insert(
                subject=c.package,
                predicate=GraphKeyword.IS,
                object_=GraphKeyword.CLASS,
            )
            for vn, vt in c.attributes.items():
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_PROPERTY,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
                )
            for fn, desc in c.methods.items():
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.IS,
                    object_=GraphKeyword.CLASS_FUNCTION,
                )
                await graph_db.insert(
                    subject=concat_namespace(c.package, fn),
                    predicate=GraphKeyword.HAS_ARGS_DESC,
                    object_=desc,
                )
|
||||
|
|
@ -18,17 +18,15 @@ from metagpt.config import CONFIG
|
|||
|
||||
def make_sk_kernel():
    """Create a Semantic Kernel instance with a "chat_completion" service chosen from CONFIG.

    An Azure chat service is registered when CONFIG.OPENAI_API_TYPE is "azure";
    otherwise an OpenAI chat service is registered.
    """
    kernel = sk.Kernel()
    if CONFIG.OPENAI_API_TYPE == "azure":
        chat_service = AzureChatCompletion(CONFIG.DEPLOYMENT_NAME, CONFIG.OPENAI_BASE_URL, CONFIG.OPENAI_API_KEY)
    else:
        chat_service = OpenAIChatCompletion(CONFIG.OPENAI_API_MODEL, CONFIG.OPENAI_API_KEY)
    kernel.add_chat_service("chat_completion", chat_service)

    return kernel
|
||||
|
|
|
|||
|
|
@ -4,11 +4,14 @@
|
|||
@Time : 2023/7/4 10:53
|
||||
@Author : alexanderwu alitrack
|
||||
@File : mermaid.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -29,7 +32,9 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
tmp = Path(f"{output_file_without_suffix}.mmd")
|
||||
tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
|
||||
await f.write(mermaid_code)
|
||||
# tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
|
||||
engine = CONFIG.mermaid_engine.lower()
|
||||
if engine == "nodejs":
|
||||
|
|
@ -88,7 +93,8 @@ async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048,
|
|||
return 0
|
||||
|
||||
|
||||
MMC1 = """classDiagram
|
||||
MMC1 = """
|
||||
classDiagram
|
||||
class Main {
|
||||
-SearchEngine search_engine
|
||||
+main() str
|
||||
|
|
@ -118,9 +124,11 @@ MMC1 = """classDiagram
|
|||
SearchEngine --> Index
|
||||
SearchEngine --> Ranking
|
||||
SearchEngine --> Summary
|
||||
Index --> KnowledgeBase"""
|
||||
Index --> KnowledgeBase
|
||||
"""
|
||||
|
||||
MMC2 = """sequenceDiagram
|
||||
MMC2 = """
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
participant SE as SearchEngine
|
||||
participant I as Index
|
||||
|
|
@ -136,11 +144,11 @@ MMC2 = """sequenceDiagram
|
|||
R-->>SE: return ranked_results
|
||||
SE->>S: summarize_results(ranked_results)
|
||||
S-->>SE: return summary
|
||||
SE-->>M: return summary"""
|
||||
|
||||
SE-->>M: return summary
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run: render both sample diagrams with the configured engine,
    # writing each one to its own output stem under METAGPT_ROOT.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(mermaid_to_file(MMC1, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/1"))
        result = loop.run_until_complete(mermaid_to_file(MMC2, METAGPT_ROOT / f"{CONFIG.mermaid_engine}/2"))
    finally:
        # Release the event loop even when a render raises.
        loop.close()
|
||||
|
|
|
|||
219
metagpt/utils/redis.py
Normal file
219
metagpt/utils/redis.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
# !/usr/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author: Hui
|
||||
# @Desc: { redis client }
|
||||
# @Date: 2022/11/28 10:12
|
||||
import json
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, Dict, Optional, Union
|
||||
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class RedisTypeEnum(Enum):
    """Redis data types."""

    String = "String"
    List = "List"
    Hash = "Hash"
    Set = "Set"
    ZSet = "ZSet"
|
||||
|
||||
|
||||
def make_url(
    dialect: str,
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[Union[str, int]] = None,
    name: Optional[Union[str, int]] = None,
) -> str:
    """Assemble a connection URL of the form ``dialect://[user][:password]@host[:port][/name]``.

    Args:
        dialect: URL scheme, e.g. ``"redis"`` or ``"sqlite"``.
        user: Optional user name; percent-encoded in the result.
        password: Optional password; percent-encoded in the result.
        host: Host name. Defaults to ``127.0.0.1`` unless the dialect starts with ``"sqlite"``.
        port: Optional port, appended only when a host is present.
        name: Optional database name or index, appended as the path component.

    Returns:
        The assembled URL string.
    """
    from urllib.parse import quote

    url_parts = [f"{dialect}://"]
    if user or password:
        # Percent-encode credentials so characters such as "@", "/" or ":" cannot be
        # mistaken for URL delimiters.
        if user:
            url_parts.append(quote(str(user), safe=""))
        if password:
            url_parts.append(":" + quote(str(password), safe=""))
        url_parts.append("@")

    if not host and not dialect.startswith("sqlite"):
        host = "127.0.0.1"

    if host:
        url_parts.append(f"{host}")
        if port:
            url_parts.append(f":{port}")

    # Compare against None because 0 is a valid name (e.g. a Redis database index).
    if name is not None:
        url_parts.append(f"/{name}")
    return "".join(url_parts)
|
||||
|
||||
|
||||
class RedisAsyncClient(aioredis.Redis):
    """Asynchronous Redis client.

    Example::

        rdb = RedisAsyncClient()
        print(rdb.url)

    Args:
        host: Server address.
        port: Server port.
        db: Database index.
        password: Password.
        decode_responses: Strings are stored in Redis UTF-8 encoded and come back as bytes;
            they must be decoded explicitly to become strings again.
        health_check_interval: Check the connection periodically to avoid
            ConnectionErrors (104, Connection reset by peer).

    All arguments, including any extra keyword arguments, are forwarded to ``aioredis.Redis``.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str = None,
        decode_responses=True,
        health_check_interval=10,
        socket_connect_timeout=5,
        retry_on_timeout=True,
        socket_keepalive=True,
        **kwargs,
    ):
        super().__init__(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=decode_responses,
            health_check_interval=health_check_interval,
            socket_connect_timeout=socket_connect_timeout,
            retry_on_timeout=retry_on_timeout,
            socket_keepalive=socket_keepalive,
            **kwargs,
        )
        # Connection URL kept for reference; note that it embeds the password when one is given.
        self.url = make_url("redis", host=host, port=port, name=db, password=password)
|
||||
|
||||
|
||||
class RedisCacheInfo(object):
    """Describes one cache entry: its key, its expiry and the Redis data structure it uses."""

    def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
        """
        Initialise the cache description.

        Args:
            key: Cache key.
            timeout: Expiry of the entry, as a number of seconds or a ``timedelta``.
            data_type: Data structure used by the entry. Optional; it only labels which
                structure the business logic uses.
        """
        self.key = key
        self.timeout = timeout
        self.data_type = data_type

    def __str__(self):
        timeout = self.timeout
        if isinstance(timeout, timedelta):
            # Report seconds so the trailing "s" is correct; str(timedelta) would give "0:01:00".
            seconds = timeout.total_seconds()
            timeout = int(seconds) if seconds.is_integer() else seconds
        return f"cache key {self.key} timeout {timeout}s"
|
||||
|
||||
|
||||
class RedisManager:
    """Holds one shared `RedisAsyncClient` at class level and offers cache helpers built on it."""

    client: RedisAsyncClient = None

    @classmethod
    def init_redis_conn(cls, host, port, password, db):
        """Create the shared Redis client on the first call; later calls leave it untouched."""
        if cls.client is not None:
            return
        cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)

    @classmethod
    async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
        """
        Store a value under the key and expiry described by a RedisCacheInfo.

        :param redis_cache_info: Cache description providing the key and the timeout.
        :param value: Value to cache.
        :return:
        """
        await cls.client.setex(name=redis_cache_info.key, time=redis_cache_info.timeout, value=value)

    @classmethod
    async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """
        Fetch the value cached under the key of a RedisCacheInfo.

        :param redis_cache_info: Cache description providing the key.
        :return: Whatever the client returns for that key.
        """
        return await cls.client.get(redis_cache_info.key)

    @classmethod
    async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """
        Delete the value cached under the key of a RedisCacheInfo.

        :param redis_cache_info: Cache description providing the key.
        :return:
        """
        await cls.client.delete(redis_cache_info.key)

    @staticmethod
    async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
        """
        Return the cached data; when nothing is cached, fetch it with the given function and cache it.

        The current version only supports string data in JSON form.
        """
        cached = await RedisManager.get_with_cache_info(cache_info)
        if cached:
            return json.loads(cached)

        data = await fetch_data_func()
        try:
            # Best effort: a failure to serialise or store must not hide the freshly fetched data.
            await RedisManager.set_with_cache_info(cache_info, json.dumps(data))
        except Exception as e:
            logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")

        return data

    @classmethod
    def is_valid(cls):
        """Tell whether the shared client has been created."""
        return cls.client is not None
|
||||
|
||||
|
||||
class Redis:
    """Thin wrapper over `RedisManager` whose operations log errors instead of raising them."""

    def __init__(self, conf: Dict = None):
        """Initialise the shared connection from CONFIG.REDIS_*; a failure is only logged.

        `conf` is accepted but not read by this implementation.
        """
        try:
            RedisManager.init_redis_conn(
                host=CONFIG.REDIS_HOST,
                port=int(CONFIG.REDIS_PORT),
                password=CONFIG.REDIS_PASSWORD,
                db=CONFIG.REDIS_DB,
            )
        except Exception as e:
            logger.warning(f"Redis initialization has failed:{e}")

    def is_valid(self):
        """Tell whether the shared Redis client exists."""
        return RedisManager.is_valid()

    async def get(self, key: str) -> str:
        """Return the value stored under `key`, or None when Redis is unavailable, the key is empty, or the read fails."""
        if not (self.is_valid() and key):
            return None
        try:
            return await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key))
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            return None

    async def set(self, key: str, data: str, timeout_sec: int):
        """Store `data` under `key` with an expiry of `timeout_sec`; do nothing when Redis is unavailable or the key is empty."""
        if not (self.is_valid() and key):
            return
        try:
            cache_info = RedisCacheInfo(key=key, timeout=timeout_sec)
            await RedisManager.set_with_cache_info(redis_cache_info=cache_info, value=data)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
|
|
@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
elif retry_state.kwargs:
|
||||
func_param_output = retry_state.kwargs.get("output", "")
|
||||
exp_str = str(retry_state.outcome.exception())
|
||||
|
||||
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
|
||||
logger.warning(
|
||||
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
|
||||
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
|
||||
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
|
||||
)
|
||||
|
||||
repaired_output = repair_invalid_json(func_param_output, exp_str)
|
||||
|
|
|
|||
170
metagpt/utils/s3.py
Normal file
170
metagpt/utils/s3.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
import base64
|
||||
import os.path
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class S3:
    """A class for interacting with Amazon S3 storage."""

    def __init__(self):
        """Create an aioboto3 session and collect the client settings from CONFIG.S3_*."""
        self.session = aioboto3.Session()
        self.auth_config = {
            "service_name": "s3",
            "aws_access_key_id": CONFIG.S3_ACCESS_KEY,
            "aws_secret_access_key": CONFIG.S3_SECRET_KEY,
            "endpoint_url": CONFIG.S3_ENDPOINT_URL,
        }

    async def upload_file(
        self,
        bucket: str,
        local_path: str,
        object_name: str,
    ) -> None:
        """Upload a file from the local path to the specified path of the storage bucket specified in s3.

        Args:
            bucket: The name of the S3 storage bucket.
            local_path: The local file path, including the file name.
            object_name: The complete path of the uploaded file to be stored in S3, including the file name.

        Raises:
            Exception: If an error occurs during the upload process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                # The whole file is read into memory before it is sent.
                async with aiofiles.open(local_path, mode="rb") as reader:
                    body = await reader.read()
                    await client.put_object(Body=body, Bucket=bucket, Key=object_name)
                    logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
        except Exception as e:
            logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
            raise

    async def get_object_url(
        self,
        bucket: str,
        object_name: str,
    ) -> str:
        """Get the URL for a downloadable or preview file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The URL for the downloadable or preview file.

        Raises:
            Exception: If an error occurs while retrieving the URL, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                file = await client.get_object(Bucket=bucket, Key=object_name)
                return str(file["Body"].url)
        except Exception as e:
            logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
            raise

    async def get_object(
        self,
        bucket: str,
        object_name: str,
    ) -> bytes:
        """Get the binary data of a file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The binary data of the requested file.

        Raises:
            Exception: If an error occurs while retrieving the file data, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                return await s3_object["Body"].read()
        except Exception as e:
            logger.error(f"Failed to get the binary data of the file: {e}")
            raise

    async def download_file(
        self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
    ) -> None:
        """Download an S3 object to a local file.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.
            local_path: The local file path where the S3 object will be downloaded.
            chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.

        Raises:
            Exception: If an error occurs during the download process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                stream = s3_object["Body"]
                async with aiofiles.open(local_path, mode="wb") as writer:
                    while True:
                        file_data = await stream.read(chunk_size)
                        if not file_data:
                            break
                        await writer.write(file_data)
        except Exception as e:
            logger.error(f"Failed to download the file from S3: {e}")
            raise

    async def cache(self, data: str, file_ext: str, format: str = "") -> Optional[str]:
        """Save data to remote S3 and return its URL.

        The data is written to a temporary file next to this module, uploaded, and the
        temporary file is removed whether or not the upload succeeds.

        Args:
            data: Content to store. Decoded from base64 when ``format`` equals BASE64_FORMAT;
                otherwise text is encoded as UTF-8 (bytes are written unchanged).
            file_ext: Suffix appended to the generated object name, e.g. ``".png"``.
            format: Encoding of ``data``; pass BASE64_FORMAT for base64 text.

        Returns:
            The URL of the uploaded object, or None if any step failed (the error is logged).
        """
        import posixpath

        object_name = uuid.uuid4().hex + file_ext
        pathname = Path(__file__).parent / object_name
        try:
            if format == BASE64_FORMAT:
                payload = base64.b64decode(data)
            elif isinstance(data, str):
                # The file is opened in binary mode, so text has to be encoded first.
                payload = data.encode("utf-8")
            else:
                payload = data
            async with aiofiles.open(str(pathname), mode="wb") as file:
                await file.write(payload)

            bucket = CONFIG.S3_BUCKET
            # S3 keys always use "/" separators, so normalise with posixpath on every platform.
            object_pathname = posixpath.normpath(f"{CONFIG.S3_BUCKET or 'system'}/{object_name}")
            await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)

            return await self.get_object_url(bucket=bucket, object_name=object_pathname)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            return None
        finally:
            # Remove the temporary file on success and on failure alike.
            pathname.unlink(missing_ok=True)

    @property
    def is_valid(self):
        """True when every S3 setting in CONFIG is set and differs from its placeholder value."""
        is_invalid = (
            not CONFIG.S3_ACCESS_KEY
            or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
            or not CONFIG.S3_SECRET_KEY
            or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
            or not CONFIG.S3_ENDPOINT_URL
            or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
            or not CONFIG.S3_BUCKET
            or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
        )
        if is_invalid:
            logger.info("S3 is invalid")
        return not is_invalid
|
||||
|
|
@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
elif "gpt-4" == model:
|
||||
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
elif "open-llm-model" == model:
|
||||
"""
|
||||
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
|
||||
inaccurate. It's a reference result.
|
||||
"""
|
||||
tokens_per_message = 0 # ignore conversation message template prefix
|
||||
tokens_per_name = 0
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"num_tokens_from_messages() is not implemented for model {model}. "
|
||||
|
|
@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
|
|||
Returns:
|
||||
int: The number of tokens in the text string.
|
||||
"""
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue