Merge pull request #615 from iorisa/fixbug/geekan/dev

fixbug: timeout & prompt_format
This commit is contained in:
geekan 2023-12-24 15:35:28 +08:00 committed by GitHub
commit 8d1a3ce171
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 41 additions and 36 deletions

View file

@ -15,6 +15,7 @@ OPENAI_API_MODEL: "gpt-4-1106-preview"
MAX_TOKENS: 4096
RPM: 10
LLM_TYPE: OpenAI # Apart from the three major options (OpenAI, MetaGPT LLM, and Azure), other large models can be distinguished based on the validity of their keys.
TIMEOUT: 60 # Timeout for llm invocation
#### if Spark
#SPARK_APPID : "YOUR_APPID"

View file

@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, root_validator, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseGPTAPI
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
@ -260,9 +261,10 @@ class ActionNode:
output_data_mapping: dict,
system_msgs: Optional[list[str]] = None,
schema="markdown", # compatible to original format
timeout=CONFIG.timeout,
) -> (str, BaseModel):
"""Use ActionOutput to wrap the output of aask"""
content = await self.llm.aask(prompt, system_msgs)
content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
logger.debug(f"llm raw output:\n{content}")
output_class = self.create_model_class(output_class_name, output_data_mapping)
@ -289,13 +291,13 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode):
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
prompt = self.compile(context=self.context, schema=schema, mode=mode)
if schema != "raw":
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
self.content = content
self.instruct_content = scontent
else:
@ -304,7 +306,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@ -320,6 +322,7 @@ class ActionNode:
:param strgy: simple/complex
- simple: run only once
- complex: run each node
:param timeout: Timeout for llm invocation.
:return: self
"""
self.set_llm(llm)
@ -328,12 +331,12 @@ class ActionNode:
schema = self.schema
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode)
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
elif strgy == "complex":
# This implicitly assumes that the node has children
tmp = {}
for _, i in self.children.items():
child = await i.simple_fill(schema=schema, mode=mode)
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)

View file

@ -51,7 +51,7 @@ class WriteDesign(Action):
"clearly and in detail."
)
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_format):
# Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory.
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
changed_prds = prds_file_repo.changed_files
@ -81,11 +81,11 @@ class WriteDesign(Action):
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
async def _new_system_design(self, context, schema=CONFIG.prompt_format):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
return node
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_format):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
system_design_doc.content = node.instruct_content.json(ensure_ascii=False)

View file

@ -45,7 +45,7 @@ class WriteTasks(Action):
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
async def run(self, with_messages, schema=CONFIG.prompt_schema):
async def run(self, with_messages, schema=CONFIG.prompt_format):
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
changed_system_designs = system_design_file_repo.changed_files
@ -92,14 +92,14 @@ class WriteTasks(Action):
await self._save_pdf(task_doc=task_doc)
return task_doc
async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
async def _run_new_tasks(self, context, schema=CONFIG.prompt_format):
node = await PM_NODE.fill(context, self.llm, schema)
# prompt_template, format_example = get_template(templates, format)
# prompt = prompt_template.format(context=context, format_example=format_example)
# rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format)
return node
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_format) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
node = await PM_NODE.fill(context, self.llm, schema)
task_doc.content = node.instruct_content.json(ensure_ascii=False)

View file

@ -181,7 +181,6 @@ class WebBrowseAndSummarize(Action):
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = WebBrowserEngine(
options={}, # FIXME: REMOVE options?
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
run_func=browse_func,
)

View file

@ -69,7 +69,7 @@ class WritePRD(Action):
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
async def run(self, with_messages, schema=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
@ -113,7 +113,7 @@ class WritePRD(Action):
# optimization in subsequent steps.
return ActionOutput(content=change_files.json(), instruct_content=change_files)
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_format) -> ActionOutput:
# sas = SearchAndSummarize()
# # rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
# rsp = ""
@ -132,7 +132,7 @@ class WritePRD(Action):
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_format) -> Document:
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)

View file

@ -109,8 +109,13 @@ class Config(metaclass=Singleton):
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
if self.openai_api_key and self.openai_api_model:
logger.info(f"OpenAI API Model: {self.openai_api_model}")
model_mappings = {
LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
}
model_name = model_mappings.get(provider)
if model_name:
logger.info(f"{provider} Model: {model_name}")
if provider:
logger.info(f"API: {provider}")
return provider
@ -187,6 +192,7 @@ class Config(metaclass=Singleton):
self.workspace_path = self.workspace_path / workspace_uid
self._ensure_workspace_exists()
self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
self.timeout = int(self._get("TIMEOUT", 3))
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""

View file

@ -64,10 +64,6 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
}
if configs:
kwargs.update(configs)
try:
default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0
except ValueError:
default_timeout = 0
kwargs["timeout"] = max(default_timeout, timeout)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
return kwargs

View file

@ -129,7 +129,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
)
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" # extract the message
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
@ -143,11 +143,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
}
if configs:
kwargs.update(configs)
try:
default_timeout = int(CONFIG.TIMEOUT) if CONFIG.TIMEOUT else 0
except ValueError:
default_timeout = 0
kwargs["timeout"] = max(default_timeout, timeout)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
return kwargs

View file

@ -311,4 +311,4 @@ class Engineer(Role):
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self._next_todo
return self.next_todo_action

View file

@ -7,7 +7,6 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
@ -39,14 +38,19 @@ class ProductManager(Role):
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)
async def _think(self) -> None:
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
self._set_state(1)
else:
self._set_state(0)
self.todo_action = any_to_name(WritePRD)
return self._rc.todo
return bool(self._rc.todo)
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.todo_action

View file

@ -6,7 +6,7 @@
from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, Dict, Literal, overload
from typing import Any, Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
@ -16,7 +16,6 @@ from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
options: Dict,
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):

View file

@ -8,7 +8,7 @@
from metagpt.config import CONFIG
def get_template(templates, schema=CONFIG.prompt_schema):
def get_template(templates, schema=CONFIG.prompt_format):
selected_templates = templates.get(schema)
if selected_templates is None:
raise ValueError(f"Can't find {schema} in passed in templates")

View file

@ -60,3 +60,4 @@ websockets~=12.0
networkx~=3.2.1
pylint~=3.0.3
google-generativeai==0.3.1
playwright==1.40.0