Merge branch 'main' into fix_typo

This commit is contained in:
Alexander Wu 2024-03-22 11:53:54 +08:00 committed by GitHub
commit 334149bb5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 491 additions and 128 deletions

View file

@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -416,7 +417,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=3,
timeout=USE_CONFIG_TIMEOUT,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
@ -448,7 +449,9 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
async def simple_fill(
self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
@ -473,7 +476,7 @@ class ActionNode:
mode="auto",
strgy="simple",
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=USE_CONFIG_TIMEOUT,
exclude=[],
):
"""Fill the node(s) with mode.

View file

@ -10,6 +10,7 @@ from typing import Optional
from pydantic import field_validator
from metagpt.const import LLM_API_TIMEOUT
from metagpt.utils.yaml_model import YamlModel
@ -74,7 +75,7 @@ class LLMConfig(YamlModel):
stream: bool = False
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
top_logprobs: Optional[int] = None
timeout: int = 60
timeout: int = 600
# For Network
proxy: Optional[str] = None
@ -88,3 +89,8 @@ class LLMConfig(YamlModel):
if v in ["", None, "YOUR_API_KEY"]:
raise ValueError("Please set your API key in config2.yaml")
return v
@field_validator("timeout")
@classmethod
def check_timeout(cls, v):
    """Normalize the timeout field: any falsy value (0, None) falls back to LLM_API_TIMEOUT."""
    if v:
        return v
    return LLM_API_TIMEOUT

View file

@ -123,7 +123,6 @@ BASE64_FORMAT = "base64"
# REDIS
REDIS_KEY = "REDIS_KEY"
LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
@ -132,3 +131,7 @@ IGNORED_MESSAGE_ID = "0"
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"
# Timeout
USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration.
LLM_API_TIMEOUT = 300

View file

@ -4,10 +4,9 @@
from metagpt.environment.base_env import Environment
from metagpt.environment.android_env.android_env import AndroidEnv
from metagpt.environment.minecraft_env.minecraft_env import MinecraftExtEnv
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
from metagpt.environment.software_env.software_env import SoftwareEnv
__all__ = ["AndroidEnv", "MinecraftExtEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
__all__ = ["AndroidEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]

View file

@ -9,11 +9,11 @@
from pathlib import Path
from typing import Dict, List, Optional
import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.context import Context
from metagpt.utils.common import aread
class Example(BaseModel):
@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel):
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
if not skill_yaml_file_name:
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
data = await reader.read(-1)
data = await aread(filename=skill_yaml_file_name)
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)

View file

@ -5,6 +5,7 @@ from anthropic import AsyncAnthropic
from anthropic.types import Message, Usage
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -41,15 +42,15 @@ class AnthropicLLM(BaseLLM):
def get_choice_text(self, resp: Message) -> str:
return resp.content[0].text
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
self._update_costs(resp.usage, self.model)
return resp
async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = Usage(input_tokens=0, output_tokens=0)

View file

@ -23,6 +23,7 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
@ -130,7 +131,7 @@ class BaseLLM(ABC):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=3,
timeout=USE_CONFIG_TIMEOUT,
stream=True,
) -> str:
if system_msgs:
@ -146,31 +147,31 @@ class BaseLLM(ABC):
else:
message.extend(msg)
logger.debug(message)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
return rsp
def _extract_assistant_rsp(self, context):
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
async def aask_batch(self, msgs: list, timeout=3) -> str:
async def aask_batch(self, msgs: list, timeout=USE_CONFIG_TIMEOUT) -> str:
"""Sequential questioning"""
context = []
for msg in msgs:
umsg = self._user_msg(msg)
context.append(umsg)
rsp_text = await self.acompletion_text(context, timeout=timeout)
rsp_text = await self.acompletion_text(context, timeout=self.get_timeout(timeout))
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
raise NotImplementedError
@abstractmethod
async def _achat_completion(self, messages: list[dict], timeout=3):
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""_achat_completion implemented by inherited class"""
@abstractmethod
async def acompletion(self, messages: list[dict], timeout=3):
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""Asynchronous version of completion
All GPTAPIs are required to provide the standard OpenAI completion interface
[
@ -181,7 +182,7 @@ class BaseLLM(ABC):
"""
@abstractmethod
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
"""_achat_completion_stream implemented by inherited class"""
@retry(
@ -191,11 +192,13 @@ class BaseLLM(ABC):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
async def acompletion_text(
self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
if stream:
return await self._achat_completion_stream(messages, timeout=timeout)
resp = await self._achat_completion(messages, timeout=timeout)
return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
return self.get_choice_text(resp)
def get_choice_text(self, rsp: dict) -> str:
@ -258,3 +261,6 @@ class BaseLLM(ABC):
"""Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
self.config.model = model
return self
def get_timeout(self, timeout: int) -> int:
    """Resolve the effective request timeout.

    Precedence: an explicit (truthy) per-call ``timeout`` wins, then the
    configured ``self.config.timeout``, then the package default
    ``LLM_API_TIMEOUT``.
    """
    if timeout:
        return timeout
    return self.config.timeout or LLM_API_TIMEOUT

View file

@ -25,6 +25,7 @@ from dashscope.common.error import (
UnsupportedApiProtocol,
)
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
@ -202,16 +203,16 @@ class DashScopeLLM(BaseLLM):
self._update_costs(dict(resp.usage))
return resp.output
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}

View file

@ -573,7 +573,7 @@ class APIRequestor:
total=request_timeout[1],
)
else:
timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
timeout = aiohttp.ClientTimeout(total=request_timeout or TIMEOUT_SECS)
if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.

View file

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
import os
from typing import Optional, Union
import google.generativeai as genai
@ -15,7 +16,8 @@ from google.generativeai.types.generation_types import (
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
@ -52,6 +54,10 @@ class GeminiLLM(BaseLLM):
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):
    """Configure the google-generativeai client, routing traffic through a proxy when one is set.

    Args:
        config: LLM configuration carrying the API key and optional proxy URL.
    """
    if config.proxy:
        logger.info(f"Use proxy: {config.proxy}")
        os.environ["HTTP_PROXY"] = config.proxy
        # Fix: the conventional env var for HTTPS traffic is HTTPS_PROXY;
        # the previous "HTTP_PROXYS" is not recognized by any HTTP stack.
        os.environ["HTTPS_PROXY"] = config.proxy
    genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
@ -118,16 +124,18 @@ class GeminiLLM(BaseLLM):
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
async def _achat_completion(
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT
) -> "AsyncGenerateContentResponse":
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
usage = await self.aget_usage(messages, resp.text)
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
**self._const_kwargs(messages, stream=True)
)

View file

@ -6,6 +6,7 @@ Author: garylin2099
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
@ -18,7 +19,7 @@ class HumanProvider(BaseLLM):
def __init__(self, config: LLMConfig):
pass
def ask(self, msg: str, timeout=3) -> str:
def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
logger.info("It's your turn, please type in your response. You may also refer to the context below")
rsp = input(msg)
if rsp in ["exit", "quit"]:
@ -31,20 +32,20 @@ class HumanProvider(BaseLLM):
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
generator: bool = False,
timeout=3,
timeout=USE_CONFIG_TIMEOUT,
) -> str:
return self.ask(msg, timeout=timeout)
return self.ask(msg, timeout=self.get_timeout(timeout))
async def _achat_completion(self, messages: list[dict], timeout=3):
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
async def acompletion(self, messages: list[dict], timeout=3):
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""dummy implementation of abstract method in base"""
return []
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
"""dummy implementation of abstract method in base"""
return ""

View file

@ -5,7 +5,7 @@
import json
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
@ -50,28 +50,28 @@ class OllamaLLM(BaseLLM):
chunk = chunk.decode(encoding)
return json.loads(chunk)
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
params=self._const_kwargs(messages),
request_timeout=LLM_API_TIMEOUT,
request_timeout=self.get_timeout(timeout),
)
resp = self._decode_and_load(resp)
usage = self.get_usage(resp)
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
stream=True,
params=self._const_kwargs(messages, stream=True),
request_timeout=LLM_API_TIMEOUT,
request_timeout=self.get_timeout(timeout),
)
collected_content = []

View file

@ -25,6 +25,7 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
@ -74,9 +75,9 @@ class OpenAILLM(BaseLLM):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
)
usage = None
collected_messages = []
@ -104,7 +105,7 @@ class OpenAILLM(BaseLLM):
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self._get_max_tokens(messages),
@ -112,20 +113,20 @@ class OpenAILLM(BaseLLM):
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": self.config.temperature,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
"timeout": self.get_timeout(timeout),
}
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=timeout)
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
@retry(
wait=wait_random_exponential(min=1, max=60),
@ -134,24 +135,24 @@ class OpenAILLM(BaseLLM):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages, timeout=timeout)
rsp = await self._achat_completion(messages, timeout=timeout)
rsp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
return self.get_choice_text(rsp)
async def _achat_completion_function(
self, messages: list[dict], timeout: int = 3, **chat_configs
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
) -> ChatCompletion:
messages = self.format_msg(messages)
kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
async def aask_code(self, messages: list[dict], timeout: int = 3, **kwargs) -> dict:
async def aask_code(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create

View file

@ -9,6 +9,7 @@ from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -107,15 +108,15 @@ class QianFanLLM(BaseLLM):
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}

View file

@ -17,6 +17,7 @@ from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -31,19 +32,19 @@ class SparkLLM(BaseLLM):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
pass
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
async def _achat_completion(self, messages: list[dict], timeout=3):
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
pass
async def acompletion(self, messages: list[dict], timeout=3):
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)
return w.run()

View file

@ -8,6 +8,7 @@ from typing import Optional
from zhipuai.types.chat.chat_completion import Completion
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@ -45,22 +46,22 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def completion(self, messages: list[dict], timeout=3) -> dict:
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp.model_dump()
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}

View file

@ -14,7 +14,6 @@ from metagpt.roles.engineer import Engineer
from metagpt.roles.qa_engineer import QaEngineer
from metagpt.roles.searcher import Searcher
from metagpt.roles.sales import Sales
from metagpt.roles.customer_service import CustomerService
__all__ = [
@ -26,5 +25,4 @@ __all__ = [
"QaEngineer",
"Searcher",
"Sales",
"CustomerService",
]

View file

@ -18,6 +18,7 @@ import csv
import importlib
import inspect
import json
import mimetypes
import os
import platform
import re
@ -29,6 +30,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import chardet
import loguru
import requests
from PIL import Image
@ -663,14 +665,21 @@ def role_raise_decorator(func):
@handle_exception
async def aread(filename: str | Path, encoding=None) -> str:
async def aread(filename: str | Path, encoding="utf-8") -> str:
"""Read file asynchronously."""
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
try:
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
except UnicodeDecodeError:
async with aiofiles.open(str(filename), mode="rb") as reader:
raw = await reader.read()
result = chardet.detect(raw)
detected_encoding = result["encoding"]
content = raw.decode(detected_encoding)
return content
async def awrite(filename: str | Path, data: str, encoding=None):
async def awrite(filename: str | Path, data: str, encoding="utf-8"):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
@ -811,3 +820,21 @@ See FAQ 5.8
"""
)
raise retry_state.outcome.exception()
def get_markdown_codeblock_type(filename: str) -> str:
    """Return the markdown code-fence language tag for *filename*.

    The tag is derived from the filename's guessed MIME type; unknown or
    unmapped types fall back to the generic "text" fence.
    """
    mime_to_fence = {
        "text/x-shellscript": "bash",
        "text/x-c++src": "cpp",
        "text/css": "css",
        "text/html": "html",
        "text/x-java": "java",
        "application/javascript": "javascript",
        "application/json": "json",
        "text/x-python": "python",
        "text/x-ruby": "ruby",
        "application/sql": "sql",
    }
    guessed, _ = mimetypes.guess_type(filename)
    return mime_to_fence.get(guessed, "text")

View file

@ -13,9 +13,7 @@ import re
from pathlib import Path
from typing import Set
import aiofiles
from metagpt.utils.common import aread
from metagpt.utils.common import aread, awrite
from metagpt.utils.exceptions import handle_exception
@ -45,8 +43,7 @@ class DependencyFile:
async def save(self):
"""Save dependencies to the file asynchronously."""
data = json.dumps(self._dependencies)
async with aiofiles.open(str(self._filename), mode="w") as writer:
await writer.write(data)
await awrite(filename=self._filename, data=data)
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
"""Update dependencies for a file asynchronously.

View file

@ -14,11 +14,9 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
import aiofiles
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.common import aread, awrite
from metagpt.utils.json_to_markdown import json_to_markdown
@ -55,8 +53,7 @@ class FileRepository:
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
await awrite(filename=str(pathname), data=content)
logger.info(f"save to: {str(pathname)}")
if dependencies is not None:

View file

@ -9,11 +9,9 @@ import asyncio
import os
from pathlib import Path
import aiofiles
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
from metagpt.utils.common import awrite, check_cmd_exists
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
tmp = Path(f"{output_file_without_suffix}.mmd")
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
await f.write(mermaid_code)
# tmp.write_text(mermaid_code, encoding="utf-8")
await awrite(filename=tmp, data=mermaid_code)
if engine == "nodejs":
if check_cmd_exists(config.mermaid.path) != 0:

View file

@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str:
content (str): llm's output from `Role._think`
"""
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
pattern = (
r"(?<!-)[0-9]" # TODO find the number using a more proper method not just extract from content using pattern
)
matches = re.findall(pattern, content, re.DOTALL)
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"

View file

@ -0,0 +1,80 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file provides functionality to convert a local repository into a markdown representation.
"""
from __future__ import annotations
import mimetypes
from pathlib import Path
from gitignore_parser import parse_gitignore
from metagpt.logs import logger
from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
from metagpt.utils.tree import tree
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
    """Render a local repository as a single markdown document.

    The document begins with a directory tree, followed by one section per
    file whose path is not excluded by the gitignore rules.

    Args:
        repo_path: Path to the local repository.
        output: Optional path; when given, the markdown is also written there.
        gitignore: Optional path to a .gitignore file; defaults to the one at
            the project root.

    Returns:
        The markdown representation of the repository.
    """
    root = Path(repo_path)
    ignore_file = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()

    sections = [await _write_dir_tree(repo_path=root, gitignore=ignore_file)]
    matcher = parse_gitignore(full_path=str(ignore_file))
    sections.append(await _write_files(repo_path=root, gitignore_rules=matcher))
    markdown = "".join(sections)

    if output:
        await awrite(filename=str(output), data=markdown, encoding="utf-8")
    return markdown
async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
    """Return a markdown section containing the repository's directory tree."""
    # Prefer the external `tree` command; on any failure fall back to the
    # pure-python traversal ("safe mode").
    try:
        listing = tree(repo_path, gitignore, run_command=True)
    except Exception as e:
        logger.info(f"{e}, using safe mode.")
        listing = tree(repo_path, gitignore, run_command=False)
    return f"## Directory Tree\n```text\n{listing}\n```\n---\n\n"
async def _write_files(repo_path, gitignore_rules) -> str:
    """Concatenate one markdown section per non-ignored file under repo_path.

    Args:
        repo_path: Repository root used to relativize file paths.
        gitignore_rules: Callable returning True for paths to skip.
    """
    sections = []
    for path in list_files(repo_path):
        if not gitignore_rules(str(path)):
            sections.append(await _write_file(filename=path, repo_path=repo_path))
    return "".join(sections)
async def _write_file(filename: Path, repo_path: Path) -> str:
    """Render a single file as a markdown section.

    Text files are embedded in a fenced code block; anything else is
    represented by a `<binary file>` placeholder.

    Args:
        filename (Path): Absolute path of the file to render.
        repo_path (Path): Repository root, used to compute the relative path
            shown in the section heading.

    Returns:
        str: The markdown section for this file.
    """
    relative_path = filename.relative_to(repo_path)
    markdown = f"## {relative_path}\n"
    mime_type, _ = mimetypes.guess_type(filename.name)
    # guess_type returns None for unknown extensions; treat those as binary,
    # otherwise `"text/" not in None` raises TypeError.
    if not mime_type or "text/" not in mime_type:
        logger.info(f"Ignore content: {filename}")
        markdown += "<binary file>\n---\n\n"
        return markdown
    content = await aread(filename, encoding="utf-8")
    # Escape sequences that would terminate the enclosing code fence or be
    # mistaken for the `---` section separator.
    content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
    code_block_type = get_markdown_codeblock_type(filename.name)
    markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
    return markdown

140
metagpt/utils/tree.py Normal file
View file

@ -0,0 +1,140 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/3/11
@Author : mashenquan
@File : tree.py
@Desc : Implement the same functionality as the `tree` command.
Example:
>>> print_tree(".")
utils
+-- serialize.py
+-- project_repo.py
+-- tree.py
+-- mmdc_playwright.py
+-- cost_manager.py
+-- __pycache__
| +-- __init__.cpython-39.pyc
| +-- redis.cpython-39.pyc
| +-- singleton.cpython-39.pyc
| +-- embedding.cpython-39.pyc
| +-- make_sk_kernel.cpython-39.pyc
| +-- file_repository.cpython-39.pyc
+-- file.py
+-- save_code.py
+-- common.py
+-- redis.py
"""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Callable, Dict, List
from gitignore_parser import parse_gitignore
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
    """Produce a `tree`-style text rendering of a directory.

    Args:
        root (str or Path): Directory from which to start traversing.
        gitignore (str or Path): Optional path to a gitignore file; matching
            entries are excluded from the output.
        run_command (bool): If True, shell out to the system `tree` command
            and return its output; otherwise walk the directory in pure
            python.

    Returns:
        str: A string representation of the directory tree, one entry per
        line, e.g.::

            utils
            +-- serialize.py
            +-- tree.py
            +-- __pycache__
            |   +-- __init__.cpython-39.pyc
    """
    root = Path(root).resolve()
    if run_command:
        return _execute_tree(root, gitignore)

    rules = parse_gitignore(gitignore) if gitignore else None
    structure = {root.name: _list_children(root=root, git_ignore_rules=rules)}
    return "\n".join(_print_tree(structure))
def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
    """Recursively map *root* to nested dicts: files -> {}, dirs -> children.

    Entries matched by *git_ignore_rules* (when given) are skipped; entries
    that vanish mid-walk or cannot be read are recorded as empty leaves.
    """
    children: Dict[str, Dict] = {}
    for entry in root.iterdir():
        if git_ignore_rules and git_ignore_rules(str(entry)):
            continue
        try:
            subtree = {} if entry.is_file() else _list_children(root=entry, git_ignore_rules=git_ignore_rules)
        except (FileNotFoundError, PermissionError, OSError):
            subtree = {}
        children[entry.name] = subtree
    return children
def _print_tree(dir_: Dict[str, Dict]) -> List[str]:
    """Render the nested dict produced by `_list_children` as tree lines.

    Args:
        dir_: Mapping of entry name to its (possibly empty) children mapping.

    Returns:
        List[str]: One formatted line per entry, using `+--` branch markers
        and `|` vertical connectors.
    """
    # NOTE: annotation fixed from `Dict[str:Dict]` (a slice, not a valid
    # generic parameterization; latent only because of
    # `from __future__ import annotations`).
    ret = []
    for name, children in dir_.items():
        ret.append(name)
        if not children:
            continue
        lines = _print_tree(children)
        for j, v in enumerate(lines):
            if v[0] not in ["+", " ", "|"]:
                # `v` starts a new sibling subtree: extend the vertical
                # connector up through the previous sibling's indented lines.
                ret = _add_line(ret)
                row = f"+-- {v}"
            else:
                row = f"    {v}"
            ret.append(row)
    return ret


def _add_line(rows: List[str]) -> List[str]:
    """Turn the leading blank of trailing indented rows into `|` so the
    branch connector reaches down to the next sibling; stops (and mutates
    nothing further) at the first non-indented row."""
    for i in range(len(rows) - 1, -1, -1):
        v = rows[i]
        if v[0] != " ":
            return rows
        rows[i] = "|" + v[1:]
    return rows
def _execute_tree(root: Path, gitignore: str | Path) -> str:
    """Run the system `tree` command and return its stdout.

    Args:
        root (Path): Directory to list.
        gitignore (str | Path): Optional gitignore file forwarded to `tree`
            via `--gitfile`.

    Returns:
        str: The command's standard output.

    Raises:
        subprocess.CalledProcessError: If `tree` exits non-zero (`check=True`).
        FileNotFoundError: If the `tree` executable is not installed.
    """
    args = ["--gitfile", str(gitignore)] if gitignore else []
    # check=True already raises CalledProcessError on a non-zero exit, so the
    # former manual `returncode != 0` check was dead code, and the
    # `except CalledProcessError as e: raise e` wrapper was a no-op.
    result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
    return result.stdout