init project

This commit is contained in:
吴承霖 2023-06-30 17:10:48 +08:00
commit c871144507
204 changed files with 7220 additions and 0 deletions

11
metagpt/utils/__init__.py Normal file
View file

@ -0,0 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/4/29 15:50
@Author : alexanderwu
@File : __init__.py
"""
from metagpt.utils.singleton import Singleton
from metagpt.utils.read_document import read_docx
from metagpt.utils.token_counter import TOKEN_COSTS, count_string_tokens, count_message_tokens

186
metagpt/utils/common.py Normal file
View file

@ -0,0 +1,186 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/4/29 16:07
@Author : alexanderwu
@File : common.py
"""
import re
import ast
import subprocess
import inspect
from pathlib import Path
from metagpt.const import PROJECT_ROOT, TMP
from metagpt.logs import logger
class CodeParser:
    """Helpers for extracting "##"-titled sections, fenced code and literals from LLM-generated markdown text."""

    @classmethod
    def parse_block(cls, block: str, text: str) -> str:
        """Return the content of the first "##" section whose title contains *block*, or "" if none matches."""
        for title, content in cls.parse_blocks(text).items():
            if block in title:
                return content
        return ""

    @classmethod
    def parse_blocks(cls, text: str) -> dict:
        """Split *text* on "##" headers into a ``{title: content}`` dict.

        The first line of each section is the title, the remainder is the body;
        both are stripped of surrounding whitespace. A section consisting of
        only a title maps to an empty string (previously this raised an
        opaque unpacking ValueError).
        """
        # Split the text into sections on the "##" marker.
        blocks = text.split("##")
        block_dict = {}
        for block in blocks:
            if block.strip() != "":
                # Title is the first line; everything after the first newline is content.
                parts = block.split("\n", 1)
                title = parts[0].strip()
                content = parts[1].strip() if len(parts) > 1 else ""
                block_dict[title] = content
        return block_dict

    @classmethod
    def parse_code(cls, block: str, text: str, lang: str = "") -> str:
        """Extract the body of a ```-fenced code block (optionally restricted to section *block*).

        Raises:
            ValueError: if no matching code fence is found.
        """
        if block:
            text = cls.parse_block(block, text)
        pattern = rf'```{lang}.*?\s+(.*?)```'
        match = re.search(pattern, text, re.DOTALL)
        if not match:
            logger.error(f"{pattern} not match following text:")
            logger.error(text)
            # Previously a bare `raise Exception`; ValueError keeps callers that
            # catch Exception working while carrying a useful message.
            raise ValueError(f"no ```{lang} code fence matched in the given text")
        return match.group(1)

    @classmethod
    def parse_str(cls, block: str, text: str, lang: str = ""):
        """Extract a quoted string literal assigned inside a fenced code block (e.g. ``name = 'x'`` -> ``x``)."""
        code = cls.parse_code(block, text, lang)
        # Keep only the right-hand side of the assignment, then drop quotes.
        code = code.split("=")[-1]
        return code.strip().strip("'").strip("\"")

    @classmethod
    def parse_file_list(cls, block: str, text: str, lang: str = "") -> list[str]:
        """Extract a Python list literal (e.g. a task/file list) from a fenced code block.

        Raises:
            ValueError: if no list literal is present in the code block.
        """
        code = cls.parse_code(block, text, lang)
        # Match an optional `name = ` prefix followed by a bracketed list literal.
        pattern = r'\s*(.*=.*)?(\[.*\])'
        match = re.search(pattern, code, re.DOTALL)
        if not match:
            # Previously a bare `raise Exception` with no message.
            raise ValueError("no list literal found in the code block")
        # Safely evaluate the literal (no arbitrary code execution).
        return ast.literal_eval(match.group(2))
class NoMoneyException(Exception):
    """Signals that an operation was aborted because available funds were insufficient."""

    def __init__(self, amount, message="Insufficient funds"):
        # Record how much was required so callers can report it or retry later.
        self.amount = amount
        self.message = message
        super().__init__(message)

    def __str__(self):
        return '{} -> Amount required: {}'.format(self.message, self.amount)
def print_members(module, indent=0):
    """Recursively print the members of *module*, labelling classes, functions and methods.

    Based on https://stackoverflow.com/questions/1796180/how-can-i-get-a-list-of-all-classes-within-current-module-in-python

    :param module: any object whose members should be inspected
    :param indent: current indentation level (two spaces are added per recursion)
    :return: None (output goes to stdout)
    """
    pad = ' ' * indent
    for member_name, member in inspect.getmembers(module):
        print(member_name, member)
        if inspect.isclass(member):
            print(f'{pad}Class: {member_name}')
            # Skip recursing through the type hierarchy itself to avoid runaway output.
            if member_name in ('__class__', '__base__'):
                continue
            print_members(member, indent + 2)
        elif inspect.isfunction(member):
            print(f'{pad}Function: {member_name}')
        elif inspect.ismethod(member):
            print(f'{pad}Method: {member_name}')
def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048):
    """Render *mermaid_code* to pdf, svg and png files via the `mmdc` CLI.

    :param mermaid_code: Mermaid diagram source text.
    :param output_file_without_suffix: output path without extension; one file
        per suffix (pdf/svg/png) is written alongside it.
    :param width: output width in pixels (mmdc ``-w``).
    :param height: output height in pixels (mmdc ``-H``).
    """
    # Write the Mermaid source to a sibling .mmd file for mmdc to read.
    tmp = Path(f'{output_file_without_suffix}.mmd')
    logger.info(tmp)
    tmp.write_text(mermaid_code)
    for suffix in ['pdf', 'svg', 'png']:
        output_file = f'{output_file_without_suffix}.{suffix}'
        # Invoke the mermaid-cli converter for each output format.
        result = subprocess.run(['mmdc', '-i', str(tmp), '-o', output_file, '-w', str(width), '-H', str(height)])
        if result.returncode != 0:
            # Failures were previously swallowed silently; surface them in the log.
            logger.error(f'mmdc failed to generate {output_file} (exit code {result.returncode})')
# Sample Mermaid class diagram (a toy search-engine design) used to exercise
# mermaid_to_file when this module is run as a script.
MMC1 = """classDiagram
class Main {
-SearchEngine search_engine
+main() str
}
class SearchEngine {
-Index index
-Ranking ranking
-Summary summary
+search(query: str) str
}
class Index {
-KnowledgeBase knowledge_base
+create_index(data: dict)
+query_index(query: str) list
}
class Ranking {
+rank_results(results: list) list
}
class Summary {
+summarize_results(results: list) str
}
class KnowledgeBase {
+update(data: dict)
+fetch_data(query: str) dict
}
Main --> SearchEngine
SearchEngine --> Index
SearchEngine --> Ranking
SearchEngine --> Summary
Index --> KnowledgeBase"""

# Companion Mermaid sequence diagram for the same search-engine design.
MMC2 = """sequenceDiagram
participant M as Main
participant SE as SearchEngine
participant I as Index
participant R as Ranking
participant S as Summary
participant KB as KnowledgeBase
M->>SE: search(query)
SE->>I: query_index(query)
I->>KB: fetch_data(query)
KB-->>I: return data
I-->>SE: return results
SE->>R: rank_results(results)
R-->>SE: return ranked_results
SE->>S: summarize_results(ranked_results)
S-->>SE: return summary
SE-->>M: return summary"""

if __name__ == '__main__':
    # Smoke test: render both sample diagrams into the project tmp directory.
    # logger.info(print_members(print_members))
    mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png')
    mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png')

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/7 16:43
@Author : alexanderwu
@File : custom_aio_session.py
"""
import ssl
import aiohttp
import openai
class CustomAioSession:
    """Async context manager that installs a dedicated aiohttp session for openai calls."""

    async def __aenter__(self):
        """Temporarily using self-signed SSL; certificate verification issues are ignored for now."""
        # NOTE(review): SSL verification is left at aiohttp defaults here; the original
        # carried disabled ssl-context snippets for self-signed certificates.
        headers = {"Accept-Encoding": "identity"}  # Disable gzip encoding
        session = aiohttp.ClientSession(headers=headers)
        # Register the session so the openai client reuses it for async requests.
        openai.aiosession.set(session)
        return session

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close and deregister whatever session __aenter__ installed."""
        current = openai.aiosession.get()
        if current:
            await current.close()
        openai.aiosession.set(None)

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/4/29 15:45
@Author : alexanderwu
@File : read_document.py
"""
import docx
def read_docx(file_path: str) -> list:
    """Open a .docx file and return the text of each paragraph, in document order.

    :param file_path: path to the .docx document.
    :return: list of paragraph text strings.
    """
    document = docx.Document(file_path)
    # One entry per paragraph, preserving document order.
    return [paragraph.text for paragraph in document.paragraphs]

View file

@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 16:15
@Author : alexanderwu
@File : singleton.py
"""
import abc
class Singleton(abc.ABCMeta, type):
    """
    Singleton metaclass: every class using it yields exactly one shared instance.
    """

    # Maps each class to its single cached instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance, creating it on first instantiation."""
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance

View file

@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/18 00:40
@Author : alexanderwu
@File : token_counter.py
ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
"""
import tiktoken
from metagpt.schema import RawMessage
# Per-model pricing table, split into prompt vs completion cost.
# NOTE(review): units are presumably USD per 1K tokens (matching OpenAI's
# published pricing at the time) — confirm against the billing code that
# consumes this table.
TOKEN_COSTS = {
    "gpt-3.5-turbo": {"prompt": 0.002, "completion": 0.002},
    "gpt-3.5-turbo-0301": {"prompt": 0.002, "completion": 0.002},
    "gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
    "gpt-4": {"prompt": 0.03, "completion": 0.06},
    "gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
    "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
    "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
}
def count_message_tokens(messages: list[RawMessage], model="gpt-3.5-turbo-0301"):
    """Returns the number of tokens used by a list of messages.

    Mirrors the OpenAI cookbook token-counting recipe: each message carries a
    fixed per-message overhead plus the encoded length of every field, and the
    reply primer costs 3 extra tokens.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    # Unpinned model aliases drift over time: delegate to the dated snapshots.
    if model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may change over time. Returning num tokens assuming gpt-3.5-turbo-0301.")
        return count_message_tokens(messages, model="gpt-3.5-turbo-0301")
    if model == "gpt-4":
        print("Warning: gpt-4 may change over time. Returning num tokens assuming gpt-4-0314.")
        return count_message_tokens(messages, model="gpt-4-0314")
    # model -> (tokens added per message, tokens added when a "name" field is present).
    overheads = {
        # every message follows <|start|>{role/name}\n{content}<|end|>\n; a name replaces the role
        "gpt-3.5-turbo-0301": (4, -1),
        "gpt-4-0314": (3, 1),
    }
    if model not in overheads:
        raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")
    tokens_per_message, tokens_per_name = overheads[model]
    total = 0
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            total += len(encoding.encode(value))
            if key == "name":
                total += tokens_per_name
    # every reply is primed with <|start|>assistant<|message|>
    return total + 3
def count_string_tokens(string: str, model_name: str) -> int:
    """
    Returns the number of tokens in a text string.

    Args:
        string (str): The text string.
        model_name (str): The name of the encoding to use. (e.g., "gpt-3.5-turbo")

    Returns:
        int: The number of tokens in the text string.
    """
    encoder = tiktoken.encoding_for_model(model_name)
    return len(encoder.encode(string))