Merge branch 'geekan:main' into main

This commit is contained in:
brucemeek 2023-08-17 11:02:27 -05:00 committed by GitHub
commit a0e6d20034
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
50 changed files with 1734 additions and 250 deletions

View file

@ -13,3 +13,12 @@ from metagpt.utils.token_counter import (
count_message_tokens,
count_string_tokens,
)
__all__ = [
"read_docx",
"Singleton",
"TOKEN_COSTS",
"count_message_tokens",
"count_string_tokens",
]

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
from __future__ import annotations
from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
class WebPage(BaseModel):
inner_text: str
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_title: Optional[str] = None
@property
def soup(self) -> BeautifulSoup:
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:
title_tag = self.soup.find("title")
self._title = title_tag.text.strip() if title_tag is not None else ""
return self._title
def get_links(self) -> Generator[str, None, None]:
for i in self.soup.find_all("a", href=True):
url = i["href"]
result = urlparse(url)
if not result.scheme and result.path:
yield urljoin(self.url, url)
elif url.startswith(("http://", "https://")):
yield urljoin(self.url, url)
def get_html_content(page: str, base: str):
soup = _get_soup(page)
return soup.get_text(strip=True)
def _get_soup(page: str):
soup = BeautifulSoup(page, "html.parser")
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
for s in soup(["style", "script", "[document]", "head", "title"]):
s.extract()
return soup

View file

@ -3,14 +3,11 @@
# @Desc : the implement of serialization and deserialization
import copy
from typing import Tuple, List, Type, Union, Dict
import pickle
from collections import defaultdict
from pydantic import create_model
from typing import Dict, List, Tuple
from metagpt.schema import Message
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
```
"""
mapping = dict()
for field, property in schema['properties'].items():
if property['type'] == 'string':
for field, property in schema["properties"].items():
if property["type"] == "string":
mapping[field] = (str, ...)
elif property['type'] == 'array' and property['items']['type'] == 'string':
elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
elif property['type'] == 'array' and property['items']['type'] == 'array':
elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `Tuple[str, str]` situation
mapping[field] = (List[Tuple[str, str]], ...)
return mapping
@ -53,11 +50,7 @@ def serialize_message(message: Message):
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
message_cp.instruct_content = {
'class': schema['title'],
'mapping': mapping,
'value': ic.dict()
}
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
msg_ser = pickle.dumps(message_cp)
return msg_ser
@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message:
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
mapping=ic['mapping'])
ic_new = ic_obj(**ic['value'])
ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
return message

124
metagpt/utils/text.py Normal file
View file

@ -0,0 +1,124 @@
from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
msgs: A generator of strings representing progressively shorter valid prompts.
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Returns:
The concatenated message segments reduced to fit within the maximum token size.
Raises:
RuntimeError: If it fails to reduce the concatenated message length.
"""
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
for msg in msgs:
if count_string_tokens(msg, model_name) < max_token:
return msg
raise RuntimeError("fail to reduce message length")
def generate_prompt_chunk(
text: str,
prompt_template: str,
model_name: str,
system_text: str,
reserved: int = 0,
) -> Generator[str, None, None]:
"""Split the text into chunks of a maximum token size.
Args:
text: The text to split.
prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Yields:
The chunk of text.
"""
paragraphs = text.splitlines(keepends=True)
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
token = count_string_tokens(paragraph, model_name)
if current_token + token <= max_token:
current_lines.append(paragraph)
current_token += token
elif token > max_token:
paragraphs = split_paragraph(paragraph) + paragraphs
continue
else:
yield prompt_template.format("".join(current_lines))
current_lines = [paragraph]
current_token = token
if current_lines:
yield prompt_template.format("".join(current_lines))
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
"""Split a paragraph into multiple parts.
Args:
paragraph: The paragraph to split.
sep: The separator character.
count: The number of parts to split the paragraph into.
Returns:
A list of split parts of the paragraph.
"""
for i in sep:
sentences = list(_split_text_with_ends(paragraph, i))
if len(sentences) <= 1:
continue
ret = ["".join(j) for j in _split_by_count(sentences, count)]
return ret
return _split_by_count(paragraph, count)
def decode_unicode_escape(text: str) -> str:
"""Decode a text with unicode escape sequences.
Args:
text: The text to decode.
Returns:
The decoded text.
"""
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0
for i in range(count):
end = start + avg + (1 if i < remainder else 0)
yield lst[start:end]
start = end
def _split_text_with_ends(text: str, sep: str = "."):
parts = []
for i in text:
parts.append(i)
if i == sep:
yield "".join(parts)
parts = []
if parts:
yield "".join(parts)

View file

@ -25,6 +25,21 @@ TOKEN_COSTS = {
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"text-embedding-ada-002": 8192,
}
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
"""
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(string))
def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
"""Calculate the maximum number of completion tokens for a given model and list of messages.
Args:
messages: A list of messages.
model: The model name.
Returns:
The maximum number of completion tokens.
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages)