Merge pull request #621 from shenchucheng/feature-huggingface

Enables MetaGPT to be used as a dependency for web applications, such as https://huggingface.co/spaces/deepwisdom/MetaGPT.
This commit is contained in:
geekan 2023-12-25 17:51:49 +08:00 committed by GitHub
commit 6ba1a897c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 39 additions and 15 deletions

View file

@ -117,4 +117,6 @@ RPM: 10
### repair operation on the content extracted from LLM's raw output. Warning: it improves the result but does not fix all cases.
# REPAIR_LLM_OUTPUT: false
# PROMPT_FORMAT: json #json or markdown
# PROMPT_FORMAT: json #json or markdown
# DISABLE_LLM_PROVIDER_CHECK: false

View file

@ -180,9 +180,11 @@ class WebBrowseAndSummarize(Action):
llm: BaseGPTAPI = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
run_func=browse_func,
web_browser_engine: WebBrowserEngine = Field(
default_factory=lambda: WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
run_func=WebBrowseAndSummarize.browse_func,
)
)
def __init__(self, **kwargs):

View file

@ -107,7 +107,9 @@ class Config(metaclass=Singleton):
self.gemini_api_key = self._get("GEMINI_API_KEY")
self.ollama_api_base = self._get("OLLAMA_API_BASE")
self.ollama_api_model = self._get("OLLAMA_API_MODEL")
_ = self.get_default_llm_provider_enum()
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy

View file

@ -6,6 +6,8 @@
@File : llm.py
"""
from typing import Optional
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.human_provider import HumanProvider
@ -14,6 +16,9 @@ from metagpt.provider.llm_provider_registry import LLM_REGISTRY
_ = HumanProvider() # Avoid pre-commit error
def LLM(provider: LLMProviderEnum = CONFIG.get_default_llm_provider_enum()) -> BaseGPTAPI:
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI:
    """Return the registered GPT API instance for *provider*.

    When *provider* is None, falls back to the provider configured via
    CONFIG, so the configured default is resolved at call time rather
    than at import time.
    """
    chosen = CONFIG.get_default_llm_provider_enum() if provider is None else provider
    return LLM_REGISTRY.get_provider(chosen)

View file

@ -8,6 +8,7 @@
import sys
from datetime import datetime
from functools import partial
from loguru import logger as _logger
@ -26,3 +27,15 @@ def define_log_level(print_level="INFO", logfile_level="DEBUG"):
logger = define_log_level()
def log_llm_stream(msg):
    """Emit one streamed LLM chunk via the pluggable stream logger."""
    _llm_stream_log(msg)
def set_llm_stream_logfunc(func):
    """Install *func* as the stream logger (e.g. to capture LLM token streams instead of printing them)."""
    global _llm_stream_log
    _llm_stream_log = func
# Default stream logger: print chunks without a trailing newline.
_llm_stream_log = partial(print, end="")

View file

@ -20,7 +20,7 @@ from tenacity import (
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
@ -119,7 +119,7 @@ class GeminiGPTAPI(BaseGPTAPI):
collected_content = []
async for chunk in resp:
content = chunk.text
print(content, end="")
log_llm_stream(content)
collected_content.append(content)
full_content = "".join(collected_content)

View file

@ -15,7 +15,7 @@ from tenacity import (
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
@ -127,7 +127,7 @@ class OllamaGPTAPI(BaseGPTAPI):
if not chunk.get("done", False):
content = self.get_choice_text(chunk)
collected_content.append(content)
print(content, end="")
log_llm_stream(content)
else:
# stream finished
usage = self.get_usage(chunk)

View file

@ -29,7 +29,7 @@ from tenacity import (
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
@ -222,7 +222,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
chunk_message = chunk.choices[0].delta # extract the message
collected_messages.append(chunk_message) # save the message
if chunk_message.content:
print(chunk_message.content, end="")
log_llm_stream(chunk_message.content)
print()
full_reply_content = "".join([m.content for m in collected_messages if m.content])

View file

@ -17,7 +17,7 @@ from tenacity import (
)
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
@ -94,7 +94,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
if event.event == ZhiPuEvent.ADD.value:
content = event.data
collected_content.append(content)
print(content, end="")
log_llm_stream(content)
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")

View file

@ -152,7 +152,7 @@ class Role(BaseModel):
builtin_class_name: str = ""
_private_attributes = {
"_llm": LLM() if not is_human else HumanProvider(),
"_llm": None,
"_role_id": _role_id,
"_states": [],
"_actions": [],