mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-06 14:22:46 +02:00
Merge branch 'main' into code_interpreter
This commit is contained in:
commit
38f21137ec
146 changed files with 4466 additions and 1375 deletions
|
|
@ -23,10 +23,10 @@ import platform
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
from typing import Any, Callable, List, Literal, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@ -423,23 +423,109 @@ def is_send_to(message: "Message", addresses: set):
|
|||
def any_to_name(val):
    """Return the short name of *val*: the final component of its dotted path.

    :param val: Any value convertible to a dotted-path string by ``any_to_str``.
    :return: The substring after the last ``.`` of the dotted representation.
    """
    dotted = any_to_str(val)
    return dotted.rsplit(".", 1)[-1]
|
||||
|
||||
|
||||
def concat_namespace(*args) -> str:
|
||||
return ":".join(str(value) for value in args)
|
||||
def concat_namespace(*args, delimiter: str = ":") -> str:
    """Join the given fields into a single namespace-prefixed string.

    Args:
        *args: Fields to concatenate; each is converted with ``str``.
        delimiter: Separator placed between fields. Defaults to ``":"``.

    Returns:
        The delimiter-joined string (empty when no fields are given).

    Example:
        >>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
        'prefix:field1:field2'
    """
    parts = [str(field) for field in args]
    return delimiter.join(parts)
|
||||
|
||||
|
||||
def split_namespace(ns_class_name: str) -> List[str]:
|
||||
return ns_class_name.split(":")
|
||||
def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
    """Split a namespace-prefixed name into its prefix and name parts.

    Args:
        ns_class_name: The namespace-prefixed name, e.g. ``'prefix:classname'``.
        delimiter: Separator between namespace components. Defaults to ``":"``.
        maxsplit: Maximum number of splits performed. Defaults to ``1``.

    Returns:
        The list of components produced by the split.

    Example:
        >>> split_namespace('prefix:classname')
        ['prefix', 'classname']

        >>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
        ['prefix', 'module', 'class']
    """
    parts = ns_class_name.split(delimiter, maxsplit)
    return parts
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
def auto_namespace(name: str, delimiter: str = ":") -> str:
    """Normalize a name so it always carries a namespace prefix.

    - Empty input yields the default prefix and name.
    - Input without a prefix gains the default ``'?'`` prefix.
    - Already-prefixed input is returned unchanged.

    Example:
        >>> auto_namespace('classname')
        '?:classname'

        >>> auto_namespace('prefix:classname')
        'prefix:classname'

        >>> auto_namespace('')
        '?:?'

        >>> auto_namespace('?:custom')
        '?:custom'
    """
    if not name:
        return f"?{delimiter}?"
    # A name that contains no delimiter has no namespace prefix yet.
    if delimiter not in name:
        return f"?{delimiter}{name}"
    return name
|
||||
|
||||
|
||||
def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
    """Encapsulate *text* with the requested affix.

    Args:
        text: The raw data to encapsulate.
        affix: Encapsulation style — ``"brace"`` wraps the text in curly
            braces, ``"url"`` wraps it in braces and URL-encodes the result,
            any other value (e.g. ``"none"``) returns the text unchanged.

    Example:
        >>> add_affix("data", affix="brace")
        '{data}'

        >>> add_affix("example.com", affix="url")
        '%7Bexample.com%7D'

        >>> add_affix("text", affix="none")
        'text'
    """
    if affix == "brace":
        return "{" + text + "}"
    if affix == "url":
        return quote("{" + text + "}")
    # Unknown/"none" affix: pass the text through untouched.
    return text
|
||||
|
||||
|
||||
def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
    """Strip the affix added by ``add_affix`` and return the payload.

    Args:
        text (str): The encapsulated text.
        affix (str, optional): Affix type — ``"brace"`` strips the outer
            curly braces, ``"url"`` URL-decodes and then strips the braces,
            any other value returns the text unchanged. Defaults to "brace".

    Returns:
        str: The extracted payload.

    Example:
        >>> remove_affix('{data}', affix="brace")
        'data'

        >>> remove_affix('%7Bexample.com%7D', affix="url")
        'example.com'

        >>> remove_affix('text', affix="none")
        'text'
    """
    if affix == "brace":
        return text[1:-1]
    if affix == "url":
        return unquote(text)[1:-1]
    # Unknown/"none" affix: pass the text through untouched.
    return text
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
|
||||
"""
|
||||
Generates a logging function to be used after a call is retried.
|
||||
|
||||
|
|
@ -626,6 +712,54 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
return files
|
||||
|
||||
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
    """Extract the body of every ```json fenced code block in *markdown_text*.

    Args:
        markdown_text: Markdown text possibly containing ```json ... ``` fences.

    Returns:
        The stripped content of each json code block, in order of appearance.
    """
    fence = re.compile(r"```json(.*?)```", re.DOTALL)
    return [block.strip() for block in fence.findall(markdown_text)]
|
||||
|
||||
|
||||
def remove_white_spaces(v: str) -> str:
    """Remove whitespace characters from *v*.

    NOTE(review): the two lookbehind alternatives together cover every
    whitespace character (quote-preceded or not), so in effect all
    whitespace is stripped — confirm whether whitespace inside quotes
    was meant to be preserved.
    """
    whitespace_pattern = r"(?<!['\"])\s|(?<=['\"])\s"
    return re.sub(whitespace_pattern, "", v)
|
||||
|
||||
|
||||
async def aread_bin(filename: str | Path) -> bytes:
    """Read a file's raw bytes asynchronously.

    Args:
        filename (Union[str, Path]): Path of the file to read.

    Returns:
        bytes: The complete file content.

    Example:
        >>> content = await aread_bin('example.txt')
        b'This is the content of the file.'

        >>> content = await aread_bin(Path('example.txt'))
        b'This is the content of the file.'
    """
    async with aiofiles.open(str(filename), mode="rb") as reader:
        return await reader.read()
|
||||
|
||||
|
||||
async def awrite_bin(filename: str | Path, data: bytes):
    """Write binary data asynchronously, creating parent directories as needed.

    Args:
        filename (Union[str, Path]): Destination file path.
        data (bytes): The binary payload to write.

    Example:
        >>> await awrite_bin('output.bin', b'This is binary data.')

        >>> await awrite_bin(Path('output.bin'), b'Another set of binary data.')
    """
    target = Path(filename)
    # Ensure the destination directory exists before opening for write.
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(str(target), mode="wb") as writer:
        await writer.write(data)
|
||||
|
||||
|
||||
def is_coroutine_func(func: Callable) -> bool:
    """Return True when *func* is declared with ``async def``."""
    return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
|
@ -689,3 +823,14 @@ def process_message(messages: Union[str, Message, list[dict], list[Message], lis
|
|||
else:
|
||||
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
|
||||
return processed_messages
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
    """Tenacity retry-error callback: log the exhausted retries, then re-raise.

    Args:
        retry_state: Tenacity state object for the call whose retries failed.

    Raises:
        BaseException: The last exception captured by the retry machinery.
    """
    logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
    # Point the user at the project FAQ before surfacing the failure.
    logger.warning(
        """
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
    )
    raise retry_state.outcome.exception()
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@
|
|||
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import NamedTuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.token_counter import TOKEN_COSTS
|
||||
from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
|
||||
|
||||
|
||||
class Costs(NamedTuple):
|
||||
|
|
@ -29,6 +30,7 @@ class CostManager(BaseModel):
|
|||
total_budget: float = 0
|
||||
max_budget: float = 10.0
|
||||
total_cost: float = 0
|
||||
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
|
|
@ -39,14 +41,17 @@ class CostManager(BaseModel):
|
|||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
if prompt_tokens + completion_tokens == 0 or not model:
|
||||
return
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
if model not in TOKEN_COSTS:
|
||||
if model not in self.token_costs:
|
||||
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
|
||||
return
|
||||
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
prompt_tokens * self.token_costs[model]["prompt"]
|
||||
+ completion_tokens * self.token_costs[model]["completion"]
|
||||
) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
|
|
@ -101,3 +106,44 @@ class TokenCostManager(CostManager):
|
|||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
|
||||
|
||||
|
||||
class FireworksCostManager(CostManager):
    """Cost manager for Fireworks-hosted models.

    Fireworks prices tokens by model-size grade rather than per model, so
    prices are resolved through ``FIREWORKS_GRADE_TOKEN_COSTS``.
    """

    def model_grade_token_costs(self, model: str) -> dict[str, float]:
        """Resolve the per-1M-token price grade for *model*.

        Args:
            model (str): Fireworks model identifier, e.g. one ending in '-13b'.

        Returns:
            dict[str, float]: Mapping with 'prompt' and 'completion' prices
            for the matched grade ('-1' is the fallback grade).
        """

        def _get_model_size(model: str) -> float:
            # Model names embed their parameter count, e.g. '...-13b'; -1 means unknown.
            size = re.findall(".*-([0-9.]+)b", model)
            return float(size[0]) if size else -1

        if "mixtral-8x7b" in model:
            token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
        else:
            model_size = _get_model_size(model)
            if 0 < model_size <= 16:
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
            elif 16 < model_size <= 80:
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
            else:
                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
        return token_costs

    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
        """
        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
        Update the total cost, prompt tokens, and completion tokens.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call.
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        token_costs = self.model_grade_token_costs(model)
        # Fireworks prices are quoted per 1M tokens.
        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
        self.total_cost += cost
        # Fix: the two f-strings were previously concatenated without a separator,
        # producing e.g. "...$0.0001Current cost..." in the log output.
        logger.info(
            f"Total running cost: ${self.total_cost:.4f} | "
            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@
|
|||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : di_graph_repository.py
|
||||
@Desc : Graph repository based on DiGraph
|
||||
@Desc : Graph repository based on DiGraph.
|
||||
This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
|
||||
specific to handling directed relationships between entities.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
|
|||
|
||||
|
||||
class DiGraphRepository(GraphRepository):
|
||||
"""Graph repository based on DiGraph."""
|
||||
|
||||
    def __init__(self, name: str, **kwargs):
        """Create an empty directed-graph repository named *name*."""
        super().__init__(name=name, **kwargs)
        # Backing store: a networkx directed graph (edges carry the predicate).
        self._repo = networkx.DiGraph()
|
||||
|
||||
    async def insert(self, subject: str, predicate: str, object_: str):
        """Insert a new triple into the directed graph repository.

        Args:
            subject (str): The subject of the triple.
            predicate (str): The predicate describing the relationship.
            object_ (str): The object of the triple.

        Example:
            await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
            # Adds a directed relationship: Node1 connects_to Node2
        """
        # NOTE(review): DiGraph keeps a single edge per (subject, object_) pair,
        # so re-inserting the same pair replaces the stored predicate — confirm intended.
        self._repo.add_edge(subject, object_, predicate=predicate)
|
||||
|
||||
    async def upsert(self, subject: str, predicate: str, object_: str):
        """Not implemented; no-op placeholder for the GraphRepository interface."""
        pass
|
||||
|
||||
    async def update(self, subject: str, predicate: str, object_: str):
        """Not implemented; no-op placeholder for the GraphRepository interface."""
        pass
|
||||
|
||||
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
|
||||
"""Retrieve triples from the directed graph repository based on specified criteria.
|
||||
|
||||
Args:
|
||||
subject (str, optional): The subject of the triple to filter by.
|
||||
predicate (str, optional): The predicate describing the relationship to filter by.
|
||||
object_ (str, optional): The object of the triple to filter by.
|
||||
|
||||
Returns:
|
||||
List[SPO]: A list of SPO objects representing the selected triples.
|
||||
|
||||
Example:
|
||||
selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
|
||||
# Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
|
||||
"""
|
||||
result = []
|
||||
for s, o, p in self._repo.edges(data="predicate"):
|
||||
if subject and subject != s:
|
||||
|
|
@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
|
|||
result.append(SPO(subject=s, predicate=p, object_=o))
|
||||
return result
|
||||
|
||||
    async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
        """Delete triples from the directed graph repository based on specified criteria.

        Args:
            subject (str, optional): The subject of the triple to filter by.
            predicate (str, optional): The predicate describing the relationship to filter by.
            object_ (str, optional): The object of the triple to filter by.

        Returns:
            int: The number of triples deleted from the repository.

        Example:
            deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
            # Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
        """
        # Reuse select() so deletion honours the same filter semantics.
        rows = await self.select(subject=subject, predicate=predicate, object_=object_)
        if not rows:
            return 0
        for r in rows:
            # NOTE(review): removal is by (subject, object_) only; since DiGraph
            # stores one edge per pair, the predicate filter cannot select among
            # multiple predicates on the same pair — confirm intended.
            self._repo.remove_edge(r.subject, r.object_)
        return len(rows)
|
||||
|
||||
def json(self) -> str:
|
||||
"""Convert the directed graph repository to a JSON-formatted string."""
|
||||
m = networkx.node_link_data(self._repo)
|
||||
data = json.dumps(m)
|
||||
return data
|
||||
|
||||
async def save(self, path: str | Path = None):
|
||||
"""Save the directed graph repository to a JSON file.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path], optional): The directory path where the JSON file will be saved.
|
||||
If not provided, the default path is taken from the 'root' key in the keyword arguments.
|
||||
"""
|
||||
data = self.json()
|
||||
path = path or self._kwargs.get("root")
|
||||
if not path.exists():
|
||||
|
|
@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
|
|||
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
|
||||
|
||||
async def load(self, pathname: str | Path):
|
||||
"""Load a directed graph repository from a JSON file."""
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
m = json.loads(data)
|
||||
self._repo = networkx.node_link_graph(m)
|
||||
|
||||
@staticmethod
|
||||
async def load_from(pathname: str | Path) -> GraphRepository:
|
||||
"""Create and load a directed graph repository from a JSON file.
|
||||
|
||||
Args:
|
||||
pathname (Union[str, Path]): The path to the JSON file to be loaded.
|
||||
|
||||
Returns:
|
||||
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
|
||||
"""
|
||||
pathname = Path(pathname)
|
||||
name = pathname.with_suffix("").name
|
||||
root = pathname.parent
|
||||
|
|
@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
|
|||
|
||||
    @property
    def root(self) -> str:
        """Root directory for the repository files.

        Taken from the 'root' keyword argument supplied at construction.
        """
        return self._kwargs.get("root")
|
||||
|
||||
@property
|
||||
def pathname(self) -> Path:
|
||||
"""Return the path and filename to the graph repository file."""
|
||||
p = Path(self.root) / self.name
|
||||
return p.with_suffix(".json")
|
||||
|
||||
    @property
    def repo(self):
        """The underlying ``networkx.DiGraph`` storage object."""
        return self._repo
|
||||
|
|
|
|||
|
|
@ -4,21 +4,28 @@
|
|||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : graph_repository.py
|
||||
@Desc : Superclass for graph repository.
|
||||
@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
|
||||
foundation for specific implementations.
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace
|
||||
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace, split_namespace
|
||||
|
||||
|
||||
class GraphKeyword:
|
||||
"""Basic words for a Graph database.
|
||||
|
||||
This class defines a set of basic words commonly used in the context of a Graph database.
|
||||
"""
|
||||
|
||||
IS = "is"
|
||||
OF = "Of"
|
||||
ON = "On"
|
||||
|
|
@ -28,51 +35,149 @@ class GraphKeyword:
|
|||
SOURCE_CODE = "source_code"
|
||||
NULL = "<null>"
|
||||
GLOBAL_VARIABLE = "global_variable"
|
||||
CLASS_FUNCTION = "class_function"
|
||||
CLASS_METHOD = "class_method"
|
||||
CLASS_PROPERTY = "class_property"
|
||||
HAS_CLASS_FUNCTION = "has_class_function"
|
||||
HAS_CLASS_METHOD = "has_class_method"
|
||||
HAS_CLASS_PROPERTY = "has_class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
HAS_DETAIL = "has_detail"
|
||||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
HAS_SEQUENCE_VIEW = "has_sequence_view"
|
||||
HAS_ARGS_DESC = "has_args_desc"
|
||||
HAS_TYPE_DESC = "has_type_desc"
|
||||
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
|
||||
HAS_CLASS_USE_CASE = "has_class_use_case"
|
||||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
HAS_PARTICIPANT = "has_participant"
|
||||
|
||||
|
||||
class SPO(BaseModel):
    """Graph repository record type: a subject-predicate-object triple.

    Attributes:
        subject (str): The subject of the triple.
        predicate (str): The predicate describing the relationship between
            the subject and the object.
        object_ (str): The object of the triple.

    Example:
        spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
        # Represents a triple: Node1 connects_to Node2
    """

    subject: str
    predicate: str
    object_: str
|
||||
|
||||
|
||||
class GraphRepository(ABC):
|
||||
"""Abstract base class for a Graph Repository.
|
||||
|
||||
This class defines the interface for a graph repository, providing methods for inserting, selecting,
|
||||
deleting, and saving graph data. Concrete implementations of this class must provide functionality
|
||||
for these operations.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
self._repo_name = name
|
||||
self._kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
"""Insert a new triple into the graph repository.
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
Args:
|
||||
subject (str): The subject of the triple.
|
||||
predicate (str): The predicate describing the relationship.
|
||||
object_ (str): The object of the triple.
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
Example:
|
||||
await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
|
||||
# Inserts a triple: Node1 connects_to Node2 into the graph repository.
|
||||
"""
|
||||
pass
|
||||
|
||||
    @abstractmethod
    async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
        """Retrieve triples from the graph repository based on specified criteria.

        Omitted (None) criteria match everything; concrete repositories must override.

        Args:
            subject (str, optional): The subject of the triple to filter by.
            predicate (str, optional): The predicate describing the relationship to filter by.
            object_ (str, optional): The object of the triple to filter by.

        Returns:
            List[SPO]: A list of SPO objects representing the selected triples.

        Example:
            selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
            # Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
        """
        pass
|
||||
|
||||
    @abstractmethod
    async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
        """Delete triples from the graph repository based on specified criteria.

        Concrete repositories must override; omitted (None) criteria match everything.

        Args:
            subject (str, optional): The subject of the triple to filter by.
            predicate (str, optional): The predicate describing the relationship to filter by.
            object_ (str, optional): The object of the triple to filter by.

        Returns:
            int: The number of triples deleted from the repository.

        Example:
            deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
            # Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
        """
        pass
|
||||
|
||||
    @abstractmethod
    async def save(self):
        """Persist any changes made to the graph repository.

        Example:
            await my_repository.save()
            # Persists any changes made to the graph repository.
        """
        pass
|
||||
|
||||
    @property
    def name(self) -> str:
        """Name of the graph repository, as given at construction time."""
        return self._repo_name
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
|
||||
"""Insert information of RepoFileInfo into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with information from the given RepoFileInfo object.
|
||||
The function inserts triples related to various dimensions such as file type, class, class method, function,
|
||||
global variable, and page info.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is, [file type])
|
||||
- (?, has class, ?)
|
||||
- (?, is, [class])
|
||||
- (?, has class method, ?)
|
||||
- (?, has function, ?)
|
||||
- (?, is, [function])
|
||||
- (?, is, global variable)
|
||||
- (?, has page info, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_file_info(my_graph_repo, my_file_info)
|
||||
# Updates 'my_graph_repo' with information from 'my_file_info'.
|
||||
"""
|
||||
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
|
||||
file_types = {".py": "python", ".js": "javascript"}
|
||||
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
|
||||
|
|
@ -95,13 +200,13 @@ class GraphRepository(ABC):
|
|||
for fn in methods:
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(file_info.file, class_name, fn),
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
for f in file_info.functions:
|
||||
# file -> function
|
||||
|
|
@ -133,7 +238,34 @@ class GraphRepository(ABC):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
|
||||
"""Insert dot format class information into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
|
||||
The function inserts triples related to various aspects of class views, including source code, file type, class,
|
||||
class property, class detail, method, composition, and aggregation.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is, source code)
|
||||
- (?, is, file type)
|
||||
- (?, has class, ?)
|
||||
- (?, is, class)
|
||||
- (?, has class property, ?)
|
||||
- (?, is, class property)
|
||||
- (?, has detail, ?)
|
||||
- (?, has method, ?)
|
||||
- (?, is composite of, ?)
|
||||
- (?, is aggregate of, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
|
||||
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
|
||||
# Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
|
||||
"""
|
||||
for c in class_views:
|
||||
filename, _ = c.package.split(":", 1)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
|
||||
|
|
@ -146,6 +278,7 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS,
|
||||
)
|
||||
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
|
||||
for vn, vt in c.attributes.items():
|
||||
# class -> property
|
||||
await graph_db.insert(
|
||||
|
|
@ -160,33 +293,61 @@ class GraphRepository(ABC):
|
|||
object_=GraphKeyword.CLASS_PROPERTY,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
|
||||
subject=concat_namespace(c.package, vn),
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=vt.model_dump_json(),
|
||||
)
|
||||
for fn, desc in c.methods.items():
|
||||
if "</I>" in desc and "<I>" not in desc:
|
||||
logger.error(desc)
|
||||
for fn, ft in c.methods.items():
|
||||
# class -> function
|
||||
await graph_db.insert(
|
||||
subject=c.package,
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(c.package, fn),
|
||||
)
|
||||
# function detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.HAS_ARGS_DESC,
|
||||
object_=desc,
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=ft.model_dump_json(),
|
||||
)
|
||||
for i in c.compositions:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
for i in c.aggregations:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_relationship_views(
|
||||
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
|
||||
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
|
||||
):
|
||||
"""Insert class relationships and labels into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with class relationship information from the given list
|
||||
of DotClassRelationship objects. The function inserts triples representing relationships and labels between
|
||||
classes.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is relationship of, ?)
|
||||
- (?, is relationship on, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
|
||||
class relationship information to be inserted.
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
|
||||
# Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
|
||||
|
||||
"""
|
||||
for r in relationship_views:
|
||||
await graph_db.insert(
|
||||
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
|
||||
|
|
@ -198,3 +359,32 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
|
||||
object_=concat_namespace(r.dest, r.label),
|
||||
)
|
||||
|
||||
    @staticmethod
    async def rebuild_composition_relationship(graph_db: "GraphRepository"):
        """Resolve '?'-namespace placeholders in composition triples.

        Composition triples are inserted with a '?' namespace prefix on the
        object. This pass maps each bare class name back to the fully
        namespace-prefixed class subject and rewrites the triple when the
        mapping is unambiguous.

        Args:
            graph_db (GraphRepository): The graph repository object to be updated.
        """
        # Map bare class name -> list of namespace-prefixed subjects declaring it.
        classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        mapping = defaultdict(list)
        for c in classes:
            name = split_namespace(c.subject)[-1]
            mapping[name].append(c.subject)

        rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
        for r in rows:
            ns, class_ = split_namespace(r.object_)
            if ns != "?":
                # Already fully qualified; nothing to rebuild.
                continue
            val = mapping[class_]
            if len(val) != 1:
                # Unknown or ambiguous class name: keep the placeholder as-is.
                continue
            ns_name = val[0]
            # Replace the placeholder triple with the fully qualified one.
            await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
            await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from metagpt.const import (
|
|||
TASK_PDF_FILE_REPO,
|
||||
TEST_CODES_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
VISUAL_GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
code_summary: FileRepository
|
||||
sd_output: FileRepository
|
||||
code_plan_and_change: FileRepository
|
||||
graph_repo: FileRepository
|
||||
|
||||
def __init__(self, git_repo):
|
||||
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
|
||||
|
|
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
|
||||
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
|
||||
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
|
||||
|
||||
|
||||
class ProjectRepo(FileRepository):
|
||||
|
|
@ -133,6 +136,7 @@ class ProjectRepo(FileRepository):
|
|||
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
return bool(code_files)
|
||||
|
||||
def with_src_path(self, path: str | Path) -> ProjectRepo:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
|
|||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
|
|
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
|
||||
# backslash problem like \" in the output
|
||||
char = rline[col_no - 1]
|
||||
nearest_char_idx = rline[col_no:].find(char)
|
||||
new_line = (
|
||||
rline[: col_no - 1]
|
||||
+ "\\"
|
||||
+ rline[col_no - 1 : col_no + nearest_char_idx]
|
||||
+ "\\"
|
||||
+ rline[col_no + nearest_char_idx :]
|
||||
)
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif not line.endswith(","):
|
||||
|
|
|
|||
|
|
@ -35,9 +35,111 @@ TOKEN_COSTS = {
|
|||
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
|
||||
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
|
||||
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
|
||||
"open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
|
||||
"open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
|
||||
"mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
|
||||
"mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
|
||||
"mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
|
||||
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
|
||||
"""
|
||||
QIANFAN_MODEL_TOKEN_COSTS = {
|
||||
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
|
||||
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
|
||||
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
|
||||
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS = {
|
||||
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
|
||||
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
|
||||
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
|
||||
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
|
||||
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
|
||||
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
|
||||
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
|
||||
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
|
||||
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
|
||||
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
|
||||
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
|
||||
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
|
||||
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
|
||||
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
|
||||
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
|
||||
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
|
||||
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
|
||||
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
|
||||
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
|
||||
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
|
||||
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
|
||||
}
|
||||
|
||||
"""
|
||||
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
||||
Different model has different detail page. Attention, some model are free for a limited time.
|
||||
"""
|
||||
DASHSCOPE_TOKEN_COSTS = {
|
||||
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"qwen-max": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
|
||||
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
|
||||
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
|
||||
FIREWORKS_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
|
||||
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
|
||||
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
|
||||
}
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
|
|
@ -61,6 +163,19 @@ TOKEN_MAX = {
|
|||
"glm-3-turbo": 128000,
|
||||
"glm-4": 128000,
|
||||
"gemini-pro": 32768,
|
||||
"moonshot-v1-8k": 8192,
|
||||
"moonshot-v1-32k": 32768,
|
||||
"moonshot-v1-128k": 128000,
|
||||
"open-mistral-7b": 8192,
|
||||
"open-mixtral-8x7b": 32768,
|
||||
"mistral-small-latest": 32768,
|
||||
"mistral-medium-latest": 32768,
|
||||
"mistral-large-latest": 32768,
|
||||
"claude-instant-1.2": 100000,
|
||||
"claude-2.0": 100000,
|
||||
"claude-2.1": 200000,
|
||||
"claude-3-sonnet-20240229": 200000,
|
||||
"claude-3-opus-20240229": 200000,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
162
metagpt/utils/visual_graph_repo.py
Normal file
162
metagpt/utils/visual_graph_repo.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File    : visual_graph_repo.py
|
||||
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class _VisualClassView(BaseModel):
    """Internal helper for VisualGraphRepo: one class plus its relationships.

    Attributes:
        package (str): Namespace-prefixed name of the class.
        uml (Optional[UMLClassView]): The UML view of the class, if available.
        generalizations (List[str]): Generalization targets of the class.
        compositions (List[str]): Composition targets of the class.
        aggregations (List[str]): Aggregation targets of the class.
    """

    package: str
    uml: Optional[UMLClassView] = None
    generalizations: List[str] = Field(default_factory=list)
    compositions: List[str] = Field(default_factory=list)
    aggregations: List[str] = Field(default_factory=list)

    def get_mermaid(self, align: int = 1) -> str:
        """Render this class as a Markdown Mermaid class-diagram fragment.

        Args:
            align (int): Number of tab characters used for indentation.

        Returns:
            str: The Mermaid class-diagram text, or "" when no UML view exists.
        """
        if not self.uml:
            return ""
        indent = "\t" * align
        text = self.uml.get_mermaid(align=align)
        # Emit relationship edges in a fixed order: generalization, composition, aggregation.
        for related, arrow in (
            (self.generalizations, "<|--"),
            (self.compositions, "*--"),
            (self.aggregations, "o--"),
        ):
            for other in related:
                text += f"{indent}{other} {arrow} {self.name}\n"
        return text

    @property
    def name(self) -> str:
        """The class name with its namespace prefix stripped."""
        return split_namespace(self.package)[-1]
|
||||
|
||||
|
||||
class VisualGraphRepo(ABC):
    """Abstract base class for graph-repository visualizers.

    Subclasses render diagrams from the SPO triples stored in `graph_db`.
    """

    # Backing SPO graph repository that subclasses query.
    graph_db: GraphRepository

    def __init__(self, graph_db):
        self.graph_db = graph_db
|
||||
|
||||
|
||||
class VisualDiGraphRepo(VisualGraphRepo):
    """Implementation of VisualGraphRepo for DiGraph graph repository.

    This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph:
    rendering Mermaid class diagrams and listing stored Mermaid sequence diagrams.
    """

    @classmethod
    async def load_from(cls, filename: str | Path):
        """Load a VisualDiGraphRepo instance from a file.

        Args:
            filename (str | Path): Path of the serialized DiGraphRepository.

        Returns:
            VisualDiGraphRepo: A visualizer bound to the loaded repository.
        """
        graph_db = await DiGraphRepository.load_from(str(filename))
        return cls(graph_db=graph_db)

    async def get_mermaid_class_view(self) -> str:
        """Return a Markdown Mermaid class-diagram text covering every class in the repository."""
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        mermaid_txt = "classDiagram\n"
        for r in rows:
            v = await self._get_class_view(ns_class_name=r.subject)
            mermaid_txt += v.get_mermaid()
        return mermaid_txt

    async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
        """Collect the UML view and relationships of one class into a _VisualClassView.

        Args:
            ns_class_name (str): Namespace-prefixed class name used as the SPO subject.

        Returns:
            _VisualClassView: The populated view object for `ns_class_name`.
        """
        rows = await self.graph_db.select(subject=ns_class_name)
        class_view = _VisualClassView(package=ns_class_name)
        # Map each relationship predicate to the list it populates; this replaces
        # three copy-pasted `elif` branches with one lookup.
        relation_targets = {
            GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF: class_view.generalizations,
            GraphKeyword.IS + COMPOSITION + GraphKeyword.OF: class_view.compositions,
            GraphKeyword.IS + AGGREGATION + GraphKeyword.OF: class_view.aggregations,
        }
        for r in rows:
            if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
                class_view.uml = UMLClassView.model_validate_json(r.object_)
                continue
            target = relation_targets.get(r.predicate)
            if target is None:
                continue
            name = self._refine_name(split_namespace(r.object_)[-1])
            if name:  # _refine_name returns "" for built-in type names and junk
                target.append(name)
        return class_view

    async def get_mermaid_sequence_views(self) -> List[Tuple[str, str]]:
        """Return all Markdown sequence diagrams with their corresponding graph repository keys.

        Returns:
            List[Tuple[str, str]]: Pairs of (graph key, Mermaid sequence-diagram text).
        """
        # Fix: the original annotation `List[(str, str)]` is invalid typing syntax
        # (a tuple subscript); the correct spelling is `List[Tuple[str, str]]`.
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        return [(r.subject, r.object_) for r in rows]

    @staticmethod
    def _refine_name(name: str) -> str:
        """Remove impurity content from the given name; drop built-in type names entirely.

        Example:
            >>> VisualDiGraphRepo._refine_name("int")
            ''

            >>> VisualDiGraphRepo._refine_name('"Class1"')
            'Class1'

            >>> VisualDiGraphRepo._refine_name("pkg.Class1")
            'Class1'
        """
        # Strip surrounding quotes, backslashes, and parentheses.
        name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
        if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
            return ""
        if "." in name:
            # Keep only the last dotted component, e.g. "pkg.Class1" -> "Class1".
            name = name.split(".")[-1]

        return name

    async def get_mermaid_sequence_view_versions(self) -> List[Tuple[str, str]]:
        """Return all versioned Markdown sequence diagrams with their corresponding graph repository keys.

        Returns:
            List[Tuple[str, str]]: Pairs of (graph key, versioned Mermaid sequence-diagram text).
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
        return [(r.subject, r.object_) for r in rows]
|
||||
Loading…
Add table
Add a link
Reference in a new issue