Merge branch 'main' into code_interpreter

This commit is contained in:
yzlin 2024-03-12 15:13:14 +08:00
commit 38f21137ec
146 changed files with 4466 additions and 1375 deletions

View file

@ -23,10 +23,10 @@ import platform
import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
@ -423,23 +423,109 @@ def is_send_to(message: "Message", addresses: set):
def any_to_name(val):
    """Return the short name of *val*: the final segment of its dotted path.

    Example:
        any_to_name("a.b.ClassName") -> "ClassName"
    """
    dotted = any_to_str(val)
    return dotted.rsplit(".", 1)[-1]
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def concat_namespace(*args, delimiter: str = ":") -> str:
    """Join the given fields into a single namespace-prefixed string.

    Example:
        >>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
        'prefix:field1:field2'
    """
    parts = [str(arg) for arg in args]
    return delimiter.join(parts)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
    """Split a namespace-prefixed name into its prefix and name parts.

    Example:
        >>> split_namespace('prefix:classname')
        ['prefix', 'classname']
        >>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
        ['prefix', 'module', 'class']
    """
    return ns_class_name.split(delimiter, maxsplit)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
def auto_namespace(name: str, delimiter: str = ":") -> str:
    """Normalize *name* into a namespace-prefixed form.

    An empty input yields the default prefix and name ('?:?'); a name with
    no namespace prefix gets the default '?' prefix prepended; an already
    prefixed name is returned unchanged.

    Example:
        >>> auto_namespace('classname')
        '?:classname'
        >>> auto_namespace('prefix:classname')
        'prefix:classname'
        >>> auto_namespace('')
        '?:?'
        >>> auto_namespace('?:custom')
        '?:custom'
    """
    if not name:
        return f"?{delimiter}?"
    # A name without the delimiter carries no namespace prefix.
    if delimiter not in name:
        return f"?{delimiter}{name}"
    return name
def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
    """Wrap *text* in an affix used to encapsulate data.

    "brace" surrounds the text with curly braces; "url" additionally
    URL-encodes the braced text; any other value leaves the text untouched.

    Example:
        >>> add_affix("data", affix="brace")
        '{data}'
        >>> add_affix("example.com", affix="url")
        '%7Bexample.com%7D'
        >>> add_affix("text", affix="none")
        'text'
    """
    if affix == "brace":
        return "{" + text + "}"
    if affix == "url":
        return quote("{" + text + "}")
    return text
def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
    """Strip an affix to recover the encapsulated data.

    Args:
        text (str): Input text carrying the affix.
        affix (str, optional): Affix type. "brace" strips the surrounding
            curly braces; "url" URL-decodes first, then strips the braces;
            anything else returns the text unchanged.

    Returns:
        str: The text with the affix removed.

    Example:
        >>> remove_affix('{data}', affix="brace")
        'data'
        >>> remove_affix('%7Bexample.com%7D', affix="url")
        'example.com'
        >>> remove_affix('text', affix="none")
        'text'
    """
    if affix == "url":
        text = unquote(text)
    if affix in ("brace", "url"):
        return text[1:-1]
    return text
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@ -626,6 +712,54 @@ def list_files(root: str | Path) -> List[Path]:
return files
def parse_json_code_block(markdown_text: str) -> List[str]:
    """Extract the body of every ```json fenced block in *markdown_text*.

    Returns:
        List[str]: The stripped contents of each ```json ... ``` block, in order.
    """
    pattern = re.compile(r"```json(.*?)```", re.DOTALL)
    return [block.strip() for block in pattern.findall(markdown_text)]
def remove_white_spaces(v: str) -> str:
    """Remove whitespace characters from *v*.

    NOTE(review): the two regex alternatives (whitespace not preceded by a
    quote, and whitespace preceded by a quote) jointly match every
    whitespace character, so this strips all whitespace — confirm the
    quote-aware split is intentional.
    """
    pattern = r"(?<!['\"])\s|(?<=['\"])\s"
    return re.sub(pattern, "", v)
async def aread_bin(filename: str | Path) -> bytes:
    """Asynchronously read a file and return its raw bytes.

    Args:
        filename (Union[str, Path]): Name or path of the file to read.

    Returns:
        bytes: The full binary content of the file.

    Example:
        >>> content = await aread_bin('example.txt')
        b'This is the content of the file.'
    """
    async with aiofiles.open(str(filename), mode="rb") as file_obj:
        return await file_obj.read()
async def awrite_bin(filename: str | Path, data: bytes):
    """Asynchronously write binary *data* to *filename*, creating parent dirs.

    Args:
        filename (Union[str, Path]): Name or path of the file to write.
        data (bytes): Binary payload to store.

    Example:
        >>> await awrite_bin('output.bin', b'This is binary data.')
    """
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(str(target), mode="wb") as writer:
        await writer.write(data)
def is_coroutine_func(func: Callable) -> bool:
    """Return True when *func* was defined with ``async def``."""
    return inspect.iscoroutinefunction(func)
@ -689,3 +823,14 @@ def process_message(messages: Union[str, Message, list[dict], list[Message], lis
else:
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
return processed_messages
def log_and_reraise(retry_state: RetryCallState):
    """Log the final failure once retries are exhausted, then re-raise it."""
    last_exc = retry_state.outcome.exception()
    logger.error(f"Retry attempts exhausted. Last exception: {last_exc}")
    logger.warning(
        """
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
    )
    raise last_exc

View file

@ -6,12 +6,13 @@
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
import re
from typing import NamedTuple
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
class Costs(NamedTuple):
@ -29,6 +30,7 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -39,14 +41,17 @@ class CostManager(BaseModel):
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
if prompt_tokens + completion_tokens == 0 or not model:
return
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in TOKEN_COSTS:
if model not in self.token_costs:
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
return
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
@ -101,3 +106,44 @@ class TokenCostManager(CostManager):
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
class FireworksCostManager(CostManager):
    """Cost manager for Fireworks-hosted models.

    Fireworks prices tokens by model-size grade rather than by exact model
    name, so costs are looked up in FIREWORKS_GRADE_TOKEN_COSTS and charged
    per million tokens.
    """

    def model_grade_token_costs(self, model: str) -> dict[str, float]:
        """Map a Fireworks model name to its per-grade token prices.

        Args:
            model (str): Fireworks model name, e.g. "llama-v2-13b-chat".

        Returns:
            dict[str, float]: {"prompt": ..., "completion": ...} prices
            (per 1M tokens) from FIREWORKS_GRADE_TOKEN_COSTS.
        """

        def _get_model_size(model: str) -> float:
            # Parse the parameter count from names like "...-13b"; -1 when absent.
            size = re.findall(".*-([0-9.]+)b", model)
            size = float(size[0]) if len(size) > 0 else -1
            return size

        if "mixtral-8x7b" in model:
            token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
        else:
            model_size = _get_model_size(model)
            if 0 < model_size <= 16:
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
            elif 16 < model_size <= 80:
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
            else:
                # Unknown or out-of-range sizes fall back to the "-1" grade.
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
        return token_costs

    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
        """
        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
        Update the total cost, prompt tokens, and completion tokens.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        token_costs = self.model_grade_token_costs(model)
        # Fireworks prices are quoted per 1M tokens.
        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
        self.total_cost += cost
        # Fix: the two f-strings were concatenated with no separator,
        # previously producing e.g. "...$0.0012Current cost...".
        logger.info(
            f"Total running cost: ${self.total_cost:.4f} | "
            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )

View file

@ -4,7 +4,9 @@
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
@Desc : Graph repository based on DiGraph.
This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
specific to handling directed relationships between entities.
"""
from __future__ import annotations
@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
"""Insert a new triple into the directed graph repository.
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Adds a directed relationship: Node1 connects_to Node2
"""
self._repo.add_edge(subject, object_, predicate=predicate)
    async def upsert(self, subject: str, predicate: str, object_: str):
        """No-op: upsert is not implemented for the DiGraph backend."""
        pass
    async def update(self, subject: str, predicate: str, object_: str):
        """No-op: update is not implemented for the DiGraph backend."""
        pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
# Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
result.append(SPO(subject=s, predicate=p, object_=o))
return result
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
# Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
rows = await self.select(subject=subject, predicate=predicate, object_=object_)
if not rows:
return 0
for r in rows:
self._repo.remove_edge(r.subject, r.object_)
return len(rows)
def json(self) -> str:
"""Convert the directed graph repository to a JSON-formatted string."""
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
"""Save the directed graph repository to a JSON file.
Args:
path (Union[str, Path], optional): The directory path where the JSON file will be saved.
If not provided, the default path is taken from the 'root' key in the keyword arguments.
"""
data = self.json()
path = path or self._kwargs.get("root")
if not path.exists():
@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
"""Create and load a directed graph repository from a JSON file.
Args:
pathname (Union[str, Path]): The path to the JSON file to be loaded.
Returns:
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
@property
def root(self) -> str:
"""Return the root directory path for the graph repository files."""
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
"""Return the path and filename to the graph repository file."""
p = Path(self.root) / self.name
return p.with_suffix(".json")
@property
def repo(self):
"""Get the underlying directed graph repository."""
return self._repo

View file

@ -4,21 +4,28 @@
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
foundation for specific implementations.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace, split_namespace
class GraphKeyword:
"""Basic words for a Graph database.
This class defines a set of basic words commonly used in the context of a Graph database.
"""
IS = "is"
OF = "Of"
ON = "On"
@ -28,51 +35,149 @@ class GraphKeyword:
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_METHOD = "class_method"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_METHOD = "has_class_method"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_DETAIL = "has_detail"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
HAS_PARTICIPANT = "has_participant"
class SPO(BaseModel):
"""Graph repository record type.
This class represents a record in a graph repository with three components:
- Subject: The subject of the triple.
- Predicate: The predicate describing the relationship between the subject and the object.
- Object: The object of the triple.
Attributes:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
# Represents a triple: Node1 connects_to Node2
"""
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
"""Abstract base class for a Graph Repository.
This class defines the interface for a graph repository, providing methods for inserting, selecting,
deleting, and saving graph data. Concrete implementations of this class must provide functionality
for these operations.
"""
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
"""Insert a new triple into the graph repository.
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
Example:
await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Inserts a triple: Node1 connects_to Node2 into the graph repository.
"""
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
# Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
# Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def save(self):
"""Save any changes made to the graph repository.
Example:
await my_repository.save()
# Persists any changes made to the graph repository.
"""
pass
@property
def name(self) -> str:
"""Get the name of the graph repository."""
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
"""Insert information of RepoFileInfo into the specified graph repository.
This function updates the provided graph repository with information from the given RepoFileInfo object.
The function inserts triples related to various dimensions such as file type, class, class method, function,
global variable, and page info.
Triple Patterns:
- (?, is, [file type])
- (?, has class, ?)
- (?, is, [class])
- (?, has class method, ?)
- (?, has function, ?)
- (?, is, [function])
- (?, is, global variable)
- (?, has page info, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
Example:
await update_graph_db_with_file_info(my_graph_repo, my_file_info)
# Updates 'my_graph_repo' with information from 'my_file_info'.
"""
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
@ -95,13 +200,13 @@ class GraphRepository(ABC):
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
for f in file_info.functions:
# file -> function
@ -133,7 +238,34 @@ class GraphRepository(ABC):
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
"""Insert dot format class information into the specified graph repository.
This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
The function inserts triples related to various aspects of class views, including source code, file type, class,
class property, class detail, method, composition, and aggregation.
Triple Patterns:
- (?, is, source code)
- (?, is, file type)
- (?, has class, ?)
- (?, is, class)
- (?, has class property, ?)
- (?, is, class property)
- (?, has detail, ?)
- (?, has method, ?)
- (?, is composite of, ?)
- (?, is aggregate of, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
Example:
await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
# Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
"""
for c in class_views:
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
@ -146,6 +278,7 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
@ -160,33 +293,61 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.HAS_DETAIL,
object_=vt.model_dump_json(),
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
for fn, ft in c.methods.items():
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
predicate=GraphKeyword.HAS_DETAIL,
object_=ft.model_dump_json(),
)
for i in c.compositions:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
)
for i in c.aggregations:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
):
"""Insert class relationships and labels into the specified graph repository.
This function updates the provided graph repository with class relationship information from the given list
of DotClassRelationship objects. The function inserts triples representing relationships and labels between
classes.
Triple Patterns:
- (?, is relationship of, ?)
- (?, is relationship on, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
class relationship information to be inserted.
Example:
await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
# Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
"""
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
@ -198,3 +359,32 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)
@staticmethod
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
"""Append namespace-prefixed information to relationship SPO (Subject-Predicate-Object) objects in the graph
repository.
This function updates the provided graph repository by appending namespace-prefixed information to existing
relationship SPO objects.
Args:
graph_db (GraphRepository): The graph repository object to be updated.
"""
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
name = split_namespace(c.subject)[-1]
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
for r in rows:
ns, class_ = split_namespace(r.object_)
if ns != "?":
continue
val = mapping[class_]
if len(val) != 1:
continue
ns_name = val[0]
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)

View file

@ -33,6 +33,7 @@ from metagpt.const import (
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
code_summary: FileRepository
sd_output: FileRepository
code_plan_and_change: FileRepository
graph_repo: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
class ProjectRepo(FileRepository):
@ -133,6 +136,7 @@ class ProjectRepo(FileRepository):
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
return bool(code_files)
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:

View file

@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
arr = output.split("\n")
new_arr = []
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
# backslash problem like \" in the output
char = rline[col_no - 1]
nearest_char_idx = rline[col_no:].find(char)
new_line = (
rline[: col_no - 1]
+ "\\"
+ rline[col_no - 1 : col_no + nearest_char_idx]
+ "\\"
+ rline[col_no + nearest_char_idx :]
)
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif not line.endswith(","):

View file

@ -35,9 +35,111 @@ TOKEN_COSTS = {
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
"open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
"open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
"mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
"mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
"mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
"claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
"""
QIANFAN_MODEL_TOKEN_COSTS = {
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}
QIANFAN_ENDPOINT_TOKEN_COSTS = {
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
"""
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Different model has different detail page. Attention, some model are free for a limited time.
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
"qwen-max": {"prompt": 0.0, "completion": 0.0},
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
}
FIREWORKS_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
@ -61,6 +163,19 @@ TOKEN_MAX = {
"glm-3-turbo": 128000,
"glm-4": 128000,
"gemini-pro": 32768,
"moonshot-v1-8k": 8192,
"moonshot-v1-32k": 32768,
"moonshot-v1-128k": 128000,
"open-mistral-7b": 8192,
"open-mixtral-8x7b": 32768,
"mistral-small-latest": 32768,
"mistral-medium-latest": 32768,
"mistral-large-latest": 32768,
"claude-instant-1.2": 100000,
"claude-2.0": 100000,
"claude-2.1": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-opus-20240229": 200000,
}

View file

@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : visualize_graph.py
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
"""
from __future__ import annotations
import re
from abc import ABC
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
    """Internal helper model used by VisualGraphRepo to describe one class node.

    Attributes:
        package (str): Namespace-prefixed name of the class.
        uml (Optional[UMLClassView]): Parsed UML view of the class, if any.
        generalizations (List[str]): Names of classes this class specializes.
        compositions (List[str]): Names of classes this class is composed of.
        aggregations (List[str]): Names of classes this class aggregates.
    """

    package: str
    uml: Optional[UMLClassView] = None
    generalizations: List[str] = Field(default_factory=list)
    compositions: List[str] = Field(default_factory=list)
    aggregations: List[str] = Field(default_factory=list)

    def get_mermaid(self, align: int = 1) -> str:
        """Render this class as Markdown Mermaid class-diagram text.

        Args:
            align (int): Number of tab characters used to indent relation lines.

        Returns:
            str: Mermaid class-diagram text, or "" when no UML view is attached.
        """
        if not self.uml:
            return ""
        indent = "\t" * align
        parts = [self.uml.get_mermaid(align=align)]
        for parent in self.generalizations:
            parts.append(f"{indent}{parent} <|-- {self.name}\n")
        for component in self.compositions:
            parts.append(f"{indent}{component} *-- {self.name}\n")
        for member in self.aggregations:
            parts.append(f"{indent}{member} o-- {self.name}\n")
        return "".join(parts)

    @property
    def name(self) -> str:
        """The class name with its namespace prefix stripped."""
        return split_namespace(self.package)[-1]
class VisualGraphRepo(ABC):
    """Abstract base class for graph-repository visualizers.

    Holds the graph repository that concrete subclasses render (e.g. as
    Mermaid class or sequence diagrams).
    """

    graph_db: GraphRepository  # the underlying graph repository to visualize

    def __init__(self, graph_db: GraphRepository):
        self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
"""Implementation of VisualGraphRepo for DiGraph graph repository.
This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
"""
@classmethod
async def load_from(cls, filename: str | Path):
"""Load a VisualDiGraphRepo instance from a file."""
graph_db = await DiGraphRepository.load_from(str(filename))
return cls(graph_db=graph_db)
async def get_mermaid_class_view(self) -> str:
"""
Returns a Markdown Mermaid class diagram code block object.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mermaid_txt = "classDiagram\n"
for r in rows:
v = await self._get_class_view(ns_class_name=r.subject)
mermaid_txt += v.get_mermaid()
return mermaid_txt
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
"""Returns the Markdown Mermaid class diagram code block object for the specified class."""
rows = await self.graph_db.select(subject=ns_class_name)
class_view = _VisualClassView(package=ns_class_name)
for r in rows:
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
class_view.uml = UMLClassView.model_validate_json(r.object_)
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.generalizations.append(name)
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.compositions.append(name)
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.aggregations.append(name)
return class_view
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
"""Returns all Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views
@staticmethod
def _refine_name(name: str) -> str:
"""Removes impurity content from the given name.
Example:
>>> _refine_name("int")
""
>>> _refine_name('"Class1"')
'Class1'
>>> _refine_name("pkg.Class1")
"Class1"
"""
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
return ""
if "." in name:
name = name.split(".")[-1]
return name
async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
"""Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views