Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev

This commit is contained in:
莘权 马 2023-12-25 22:39:17 +08:00
commit ef1bc01c99
33 changed files with 714 additions and 140 deletions

View file

@ -105,6 +105,7 @@ You are a member of a professional butler team and will provide helpful suggesti
"""
# TOTEST
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None

View file

@ -20,6 +20,7 @@ from metagpt.logs import logger
from metagpt.schema import Message
# TOTEST
class ArgumentsParingAction(Action):
skill: Skill
ask: str

View file

@ -91,6 +91,7 @@ flowchart TB
"""
# TOTEST
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)

View file

@ -15,6 +15,7 @@ from metagpt.logs import logger
from metagpt.schema import Message
# TOTEST
class TalkAction(Action):
context: str
history_summary: str = ""

View file

@ -81,6 +81,7 @@ class Config(metaclass=Singleton):
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self) -> LLMProviderEnum:
"""Get first valid LLM provider enum"""
mappings = {
LLMProviderEnum.OPENAI: bool(
self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL

View file

@ -7,13 +7,13 @@
"""
import anthropic
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
from metagpt.config import CONFIG
class Claude2:
def ask(self, prompt):
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
@ -23,10 +23,10 @@ class Claude2:
)
return res.completion
async def aask(self, prompt):
client = Anthropic(api_key=CONFIG.anthropic_api_key)
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)
res = client.completions.create(
res = await aclient.completions.create(
model="claude-2",
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
max_tokens_to_sample=1000,

View file

@ -162,7 +162,7 @@ class BaseGPTAPI(BaseChatbot):
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
return "\n".join([f"{i.role}: {i.content}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""

View file

@ -133,7 +133,9 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -79,6 +79,9 @@ class GeminiGPTAPI(BaseGPTAPI):
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text
@ -133,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -57,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI):
self.model = config.ollama_api_model
def close(self):
pass
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
@ -144,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
class SparkAPI(BaseGPTAPI):
class SparkGPTAPI(BaseGPTAPI):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def close(self):
pass
def ask(self, msg: str) -> str:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = self.completion(message)
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)

View file

@ -64,6 +64,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def close(self):
pass
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
@ -131,6 +134,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages, timeout=timeout)
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -48,7 +48,7 @@ def check_cmd_exists(command) -> int:
return result
def require_python_version(req_version: tuple[int]) -> bool:
def require_python_version(req_version: Tuple) -> bool:
if not (2 <= len(req_version) <= 3):
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
return True if sys.version_info > req_version else False