mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-11 16:52:37 +02:00
Merge branch 'main' into fix_typo
This commit is contained in:
commit
334149bb5d
34 changed files with 491 additions and 128 deletions
|
|
@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -416,7 +417,7 @@ class ActionNode:
|
|||
images: Optional[Union[str, list[str]]] = None,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
schema="markdown", # compatible to original format
|
||||
timeout=3,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
) -> (str, BaseModel):
|
||||
"""Use ActionOutput to wrap the output of aask"""
|
||||
content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
|
||||
|
|
@ -448,7 +449,9 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=3, exclude=None):
|
||||
async def simple_fill(
|
||||
self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
|
||||
):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
|
|
@ -473,7 +476,7 @@ class ActionNode:
|
|||
mode="auto",
|
||||
strgy="simple",
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
exclude=[],
|
||||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Optional
|
|||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
|
|
@ -74,7 +75,7 @@ class LLMConfig(YamlModel):
|
|||
stream: bool = False
|
||||
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
|
||||
top_logprobs: Optional[int] = None
|
||||
timeout: int = 60
|
||||
timeout: int = 600
|
||||
|
||||
# For Network
|
||||
proxy: Optional[str] = None
|
||||
|
|
@ -88,3 +89,8 @@ class LLMConfig(YamlModel):
|
|||
if v in ["", None, "YOUR_API_KEY"]:
|
||||
raise ValueError("Please set your API key in config2.yaml")
|
||||
return v
|
||||
|
||||
@field_validator("timeout")
|
||||
@classmethod
|
||||
def check_timeout(cls, v):
|
||||
return v or LLM_API_TIMEOUT
|
||||
|
|
|
|||
|
|
@ -123,7 +123,6 @@ BASE64_FORMAT = "base64"
|
|||
|
||||
# REDIS
|
||||
REDIS_KEY = "REDIS_KEY"
|
||||
LLM_API_TIMEOUT = 300
|
||||
|
||||
# Message id
|
||||
IGNORED_MESSAGE_ID = "0"
|
||||
|
|
@ -132,3 +131,7 @@ IGNORED_MESSAGE_ID = "0"
|
|||
GENERALIZATION = "Generalize"
|
||||
COMPOSITION = "Composite"
|
||||
AGGREGATION = "Aggregate"
|
||||
|
||||
# Timeout
|
||||
USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration.
|
||||
LLM_API_TIMEOUT = 300
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@
|
|||
|
||||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.environment.android_env.android_env import AndroidEnv
|
||||
from metagpt.environment.minecraft_env.minecraft_env import MinecraftExtEnv
|
||||
from metagpt.environment.werewolf_env.werewolf_env import WerewolfEnv
|
||||
from metagpt.environment.stanford_town_env.stanford_town_env import StanfordTownEnv
|
||||
from metagpt.environment.software_env.software_env import SoftwareEnv
|
||||
|
||||
|
||||
__all__ = ["AndroidEnv", "MinecraftExtEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
|
||||
__all__ = ["AndroidEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"]
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiofiles
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
class Example(BaseModel):
|
||||
|
|
@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel):
|
|||
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
|
||||
if not skill_yaml_file_name:
|
||||
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
|
||||
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
|
||||
data = await reader.read(-1)
|
||||
data = await aread(filename=skill_yaml_file_name)
|
||||
skill_data = yaml.safe_load(data)
|
||||
return SkillsDeclaration(**skill_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from anthropic import AsyncAnthropic
|
|||
from anthropic.types import Message, Usage
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -41,15 +42,15 @@ class AnthropicLLM(BaseLLM):
|
|||
def get_choice_text(self, resp: Message) -> str:
|
||||
return resp.content[0].text
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
|
||||
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
|
||||
self._update_costs(resp.usage, self.model)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = Usage(input_tokens=0, output_tokens=0)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import log_and_reraise
|
||||
|
|
@ -130,7 +131,7 @@ class BaseLLM(ABC):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=3,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
if system_msgs:
|
||||
|
|
@ -146,31 +147,31 @@ class BaseLLM(ABC):
|
|||
else:
|
||||
message.extend(msg)
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
|
||||
return rsp
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
async def aask_batch(self, msgs: list, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
"""Sequential questioning"""
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp_text = await self.acompletion_text(context, timeout=timeout)
|
||||
rsp_text = await self.acompletion_text(context, timeout=self.get_timeout(timeout))
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=USE_CONFIG_TIMEOUT, **kwargs) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
"""_achat_completion implemented by inherited class"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
"""Asynchronous version of completion
|
||||
All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
|
|
@ -181,7 +182,7 @@ class BaseLLM(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
"""_achat_completion_stream implemented by inherited class"""
|
||||
|
||||
@retry(
|
||||
|
|
@ -191,11 +192,13 @@ class BaseLLM(ABC):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
|
||||
) -> str:
|
||||
"""Asynchronous version of completion. Return str. Support stream-print"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages, timeout=timeout)
|
||||
resp = await self._achat_completion(messages, timeout=timeout)
|
||||
return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
|
||||
resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
return self.get_choice_text(resp)
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
|
|
@ -258,3 +261,6 @@ class BaseLLM(ABC):
|
|||
"""Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
|
||||
self.config.model = model
|
||||
return self
|
||||
|
||||
def get_timeout(self, timeout: int) -> int:
|
||||
return timeout or self.config.timeout or LLM_API_TIMEOUT
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from dashscope.common.error import (
|
|||
UnsupportedApiProtocol,
|
||||
)
|
||||
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM, LLMConfig
|
||||
from metagpt.provider.llm_provider_registry import LLMType, register_provider
|
||||
|
|
@ -202,16 +203,16 @@ class DashScopeLLM(BaseLLM):
|
|||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
|
||||
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
|
|||
|
|
@ -573,7 +573,7 @@ class APIRequestor:
|
|||
total=request_timeout[1],
|
||||
)
|
||||
else:
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout if request_timeout else TIMEOUT_SECS)
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout or TIMEOUT_SECS)
|
||||
|
||||
if files:
|
||||
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
|
|
@ -15,7 +16,8 @@ from google.generativeai.types.generation_types import (
|
|||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -52,6 +54,10 @@ class GeminiLLM(BaseLLM):
|
|||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
if config.proxy:
|
||||
logger.info(f"Use proxy: {config.proxy}")
|
||||
os.environ["HTTP_PROXY"] = config.proxy
|
||||
os.environ["HTTP_PROXYS"] = config.proxy
|
||||
genai.configure(api_key=config.api_key)
|
||||
|
||||
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
|
||||
|
|
@ -118,16 +124,18 @@ class GeminiLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
|
||||
async def _achat_completion(
|
||||
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT
|
||||
) -> "AsyncGenerateContentResponse":
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
|
||||
usage = await self.aget_usage(messages, resp.text)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
|
||||
**self._const_kwargs(messages, stream=True)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Author: garylin2099
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
|
@ -18,7 +19,7 @@ class HumanProvider(BaseLLM):
|
|||
def __init__(self, config: LLMConfig):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
def ask(self, msg: str, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
logger.info("It's your turn, please type in your response. You may also refer to the context below")
|
||||
rsp = input(msg)
|
||||
if rsp in ["exit", "quit"]:
|
||||
|
|
@ -31,20 +32,20 @@ class HumanProvider(BaseLLM):
|
|||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=3,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
return self.ask(msg, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import json
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
|
|
@ -50,28 +50,28 @@ class OllamaLLM(BaseLLM):
|
|||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
params=self._const_kwargs(messages),
|
||||
request_timeout=LLM_API_TIMEOUT,
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
)
|
||||
resp = self._decode_and_load(resp)
|
||||
usage = self.get_usage(resp)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
stream_resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
stream=True,
|
||||
params=self._const_kwargs(messages, stream=True),
|
||||
request_timeout=LLM_API_TIMEOUT,
|
||||
request_timeout=self.get_timeout(timeout),
|
||||
)
|
||||
|
||||
collected_content = []
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
|
|
@ -74,9 +75,9 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages, timeout=timeout), stream=True
|
||||
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
|
||||
)
|
||||
usage = None
|
||||
collected_messages = []
|
||||
|
|
@ -104,7 +105,7 @@ class OpenAILLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
|
|
@ -112,20 +113,20 @@ class OpenAILLM(BaseLLM):
|
|||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": self.config.temperature,
|
||||
"model": self.model,
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
"timeout": self.get_timeout(timeout),
|
||||
}
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
kwargs = self._cons_kwargs(messages, timeout=timeout)
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
|
||||
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
|
|
@ -134,24 +135,24 @@ class OpenAILLM(BaseLLM):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
rsp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def _achat_completion_function(
|
||||
self, messages: list[dict], timeout: int = 3, **chat_configs
|
||||
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
|
||||
) -> ChatCompletion:
|
||||
messages = self.format_msg(messages)
|
||||
kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
|
||||
kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def aask_code(self, messages: list[dict], timeout: int = 3, **kwargs) -> dict:
|
||||
async def aask_code(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from qianfan import ChatCompletion
|
|||
from qianfan.resources.typing import JsonBody
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -107,15 +108,15 @@ class QianFanLLM(BaseLLM):
|
|||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from wsgiref.handlers import format_date_time
|
|||
import websocket # 使用websocket_client
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -31,19 +32,19 @@ class SparkLLM(BaseLLM):
|
|||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
# 不支持
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Optional
|
|||
from zhipuai.types.chat.chat_completion import Completion
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -45,22 +46,22 @@ class ZhiPuAILLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
self._update_costs(usage)
|
||||
return resp.model_dump()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
resp = await self.llm.acreate(**self._const_kwargs(messages))
|
||||
usage = resp.get("usage", {})
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
|
||||
return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from metagpt.roles.engineer import Engineer
|
|||
from metagpt.roles.qa_engineer import QaEngineer
|
||||
from metagpt.roles.searcher import Searcher
|
||||
from metagpt.roles.sales import Sales
|
||||
from metagpt.roles.customer_service import CustomerService
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -26,5 +25,4 @@ __all__ = [
|
|||
"QaEngineer",
|
||||
"Searcher",
|
||||
"Sales",
|
||||
"CustomerService",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import csv
|
|||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
|
|
@ -29,6 +30,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union
|
|||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import chardet
|
||||
import loguru
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
|
@ -663,14 +665,21 @@ def role_raise_decorator(func):
|
|||
|
||||
|
||||
@handle_exception
|
||||
async def aread(filename: str | Path, encoding=None) -> str:
|
||||
async def aread(filename: str | Path, encoding="utf-8") -> str:
|
||||
"""Read file asynchronously."""
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
try:
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
except UnicodeDecodeError:
|
||||
async with aiofiles.open(str(filename), mode="rb") as reader:
|
||||
raw = await reader.read()
|
||||
result = chardet.detect(raw)
|
||||
detected_encoding = result["encoding"]
|
||||
content = raw.decode(detected_encoding)
|
||||
return content
|
||||
|
||||
|
||||
async def awrite(filename: str | Path, data: str, encoding=None):
|
||||
async def awrite(filename: str | Path, data: str, encoding="utf-8"):
|
||||
"""Write file asynchronously."""
|
||||
pathname = Path(filename)
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -811,3 +820,21 @@ See FAQ 5.8
|
|||
"""
|
||||
)
|
||||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
def get_markdown_codeblock_type(filename: str) -> str:
|
||||
"""Return the markdown code-block type corresponding to the file extension."""
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
mappings = {
|
||||
"text/x-shellscript": "bash",
|
||||
"text/x-c++src": "cpp",
|
||||
"text/css": "css",
|
||||
"text/html": "html",
|
||||
"text/x-java": "java",
|
||||
"application/javascript": "javascript",
|
||||
"application/json": "json",
|
||||
"text/x-python": "python",
|
||||
"text/x-ruby": "ruby",
|
||||
"application/sql": "sql",
|
||||
}
|
||||
return mappings.get(mime_type, "text")
|
||||
|
|
|
|||
|
|
@ -13,9 +13,7 @@ import re
|
|||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.common import aread, awrite
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -45,8 +43,7 @@ class DependencyFile:
|
|||
async def save(self):
|
||||
"""Save dependencies to the file asynchronously."""
|
||||
data = json.dumps(self._dependencies)
|
||||
async with aiofiles.open(str(self._filename), mode="w") as writer:
|
||||
await writer.write(data)
|
||||
await awrite(filename=self._filename, data=data)
|
||||
|
||||
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
|
||||
"""Update dependencies for a file asynchronously.
|
||||
|
|
|
|||
|
|
@ -14,11 +14,9 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.common import aread, awrite
|
||||
from metagpt.utils.json_to_markdown import json_to_markdown
|
||||
|
||||
|
||||
|
|
@ -55,8 +53,7 @@ class FileRepository:
|
|||
pathname = self.workdir / filename
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = content if content else "" # avoid `argument must be str, not None` to make it continue
|
||||
async with aiofiles.open(str(pathname), mode="w") as writer:
|
||||
await writer.write(content)
|
||||
await awrite(filename=str(pathname), data=content)
|
||||
logger.info(f"save to: {str(pathname)}")
|
||||
|
||||
if dependencies is not None:
|
||||
|
|
|
|||
|
|
@ -9,11 +9,9 @@ import asyncio
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import check_cmd_exists
|
||||
from metagpt.utils.common import awrite, check_cmd_exists
|
||||
|
||||
|
||||
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
|
|
@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
tmp = Path(f"{output_file_without_suffix}.mmd")
|
||||
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
|
||||
await f.write(mermaid_code)
|
||||
# tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
await awrite(filename=tmp, data=mermaid_code)
|
||||
|
||||
if engine == "nodejs":
|
||||
if check_cmd_exists(config.mermaid.path) != 0:
|
||||
|
|
|
|||
|
|
@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str:
|
|||
content (str): llm's output from `Role._think`
|
||||
"""
|
||||
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
|
||||
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
|
||||
pattern = (
|
||||
r"(?<!-)[0-9]" # TODO find the number using a more proper method not just extract from content using pattern
|
||||
)
|
||||
matches = re.findall(pattern, content, re.DOTALL)
|
||||
matches = list(set(matches))
|
||||
state = matches[0] if len(matches) > 0 else "-1"
|
||||
|
|
|
|||
80
metagpt/utils/repo_to_markdown.py
Normal file
80
metagpt/utils/repo_to_markdown.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This file provides functionality to convert a local repository into a markdown representation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
|
||||
from metagpt.utils.tree import tree
|
||||
|
||||
|
||||
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
|
||||
"""
|
||||
Convert a local repository into a markdown representation.
|
||||
|
||||
This function takes a path to a local repository and generates a markdown representation of the repository structure,
|
||||
including directory trees and file listings.
|
||||
|
||||
Args:
|
||||
repo_path (str | Path): The path to the local repository.
|
||||
output (str | Path, optional): The path to save the generated markdown file. Defaults to None.
|
||||
gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The markdown representation of the repository.
|
||||
"""
|
||||
repo_path = Path(repo_path)
|
||||
gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()
|
||||
|
||||
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore)
|
||||
|
||||
gitignore_rules = parse_gitignore(full_path=str(gitignore))
|
||||
markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules)
|
||||
|
||||
if output:
|
||||
await awrite(filename=str(output), data=markdown, encoding="utf-8")
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
|
||||
try:
|
||||
content = tree(repo_path, gitignore, run_command=True)
|
||||
except Exception as e:
|
||||
logger.info(f"{e}, using safe mode.")
|
||||
content = tree(repo_path, gitignore, run_command=False)
|
||||
|
||||
doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n"
|
||||
return doc
|
||||
|
||||
|
||||
async def _write_files(repo_path, gitignore_rules) -> str:
|
||||
filenames = list_files(repo_path)
|
||||
markdown = ""
|
||||
for filename in filenames:
|
||||
if gitignore_rules(str(filename)):
|
||||
continue
|
||||
markdown += await _write_file(filename=filename, repo_path=repo_path)
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_file(filename: Path, repo_path: Path) -> str:
|
||||
relative_path = filename.relative_to(repo_path)
|
||||
markdown = f"## {relative_path}\n"
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(filename.name)
|
||||
if "text/" not in mime_type:
|
||||
logger.info(f"Ignore content: {filename}")
|
||||
markdown += "<binary file>\n---\n\n"
|
||||
return markdown
|
||||
content = await aread(filename, encoding="utf-8")
|
||||
content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
|
||||
code_block_type = get_markdown_codeblock_type(filename.name)
|
||||
markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
|
||||
return markdown
|
||||
140
metagpt/utils/tree.py
Normal file
140
metagpt/utils/tree.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/3/11
|
||||
@Author : mashenquan
|
||||
@File : tree.py
|
||||
@Desc : Implement the same functionality as the `tree` command.
|
||||
Example:
|
||||
>>> print_tree(".")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- cost_manager.py
|
||||
+-- __pycache__
|
||||
| +-- __init__.cpython-39.pyc
|
||||
| +-- redis.cpython-39.pyc
|
||||
| +-- singleton.cpython-39.pyc
|
||||
| +-- embedding.cpython-39.pyc
|
||||
| +-- make_sk_kernel.cpython-39.pyc
|
||||
| +-- file_repository.cpython-39.pyc
|
||||
+-- file.py
|
||||
+-- save_code.py
|
||||
+-- common.py
|
||||
+-- redis.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
|
||||
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
"""
|
||||
Recursively traverses the directory structure and prints it out in a tree-like format.
|
||||
|
||||
Args:
|
||||
root (str or Path): The root directory from which to start traversing.
|
||||
gitignore (str or Path): The filename of gitignore file.
|
||||
run_command (bool): Whether to execute `tree` command. Execute the `tree` command and return the result if True,
|
||||
otherwise execute python code instead.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the directory tree.
|
||||
|
||||
Example:
|
||||
>>> tree(".")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- __pycache__
|
||||
| +-- __init__.cpython-39.pyc
|
||||
| +-- redis.cpython-39.pyc
|
||||
| +-- singleton.cpython-39.pyc
|
||||
+-- parse_docstring.py
|
||||
|
||||
>>> tree(".", gitignore="../../.gitignore")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- parse_docstring.py
|
||||
|
||||
>>> tree(".", gitignore="../../.gitignore", run_command=True)
|
||||
utils
|
||||
├── serialize.py
|
||||
├── project_repo.py
|
||||
├── tree.py
|
||||
├── mmdc_playwright.py
|
||||
└── parse_docstring.py
|
||||
|
||||
|
||||
"""
|
||||
root = Path(root).resolve()
|
||||
if run_command:
|
||||
return _execute_tree(root, gitignore)
|
||||
|
||||
git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
|
||||
dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
|
||||
v = _print_tree(dir_)
|
||||
return "\n".join(v)
|
||||
|
||||
|
||||
def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
|
||||
dir_ = {}
|
||||
for i in root.iterdir():
|
||||
if git_ignore_rules and git_ignore_rules(str(i)):
|
||||
continue
|
||||
try:
|
||||
if i.is_file():
|
||||
dir_[i.name] = {}
|
||||
else:
|
||||
dir_[i.name] = _list_children(root=i, git_ignore_rules=git_ignore_rules)
|
||||
except (FileNotFoundError, PermissionError, OSError):
|
||||
dir_[i.name] = {}
|
||||
return dir_
|
||||
|
||||
|
||||
def _print_tree(dir_: Dict[str:Dict]) -> List[str]:
|
||||
ret = []
|
||||
for name, children in dir_.items():
|
||||
ret.append(name)
|
||||
if not children:
|
||||
continue
|
||||
lines = _print_tree(children)
|
||||
for j, v in enumerate(lines):
|
||||
if v[0] not in ["+", " ", "|"]:
|
||||
ret = _add_line(ret)
|
||||
row = f"+-- {v}"
|
||||
else:
|
||||
row = f" {v}"
|
||||
ret.append(row)
|
||||
return ret
|
||||
|
||||
|
||||
def _add_line(rows: List[str]) -> List[str]:
|
||||
for i in range(len(rows) - 1, -1, -1):
|
||||
v = rows[i]
|
||||
if v[0] != " ":
|
||||
return rows
|
||||
rows[i] = "|" + v[1:]
|
||||
return rows
|
||||
|
||||
|
||||
def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
args = ["--gitfile", str(gitignore)] if gitignore else []
|
||||
try:
|
||||
result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"tree exits with code {result.returncode}")
|
||||
return result.stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise e
|
||||
Loading…
Add table
Add a link
Reference in a new issue