mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
refine code
This commit is contained in:
parent
f5bb850f25
commit
5cd5eebc5b
12 changed files with 67 additions and 56 deletions
|
|
@ -14,8 +14,6 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.context import ContextMixin
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import (
|
||||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
|
|
@ -30,7 +28,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
name: str = ""
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
prefix: str = "" # aask*时会加上prefix,作为system_message
|
||||
desc: str = "" # for skill manager
|
||||
|
|
|
|||
|
|
@ -133,7 +133,6 @@ class GenerateTable(Action):
|
|||
|
||||
name: str = "GenerateTable"
|
||||
i_context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -178,7 +178,6 @@ class WebBrowseAndSummarize(Action):
|
|||
|
||||
name: str = "WebBrowseAndSummarize"
|
||||
i_context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: Optional[WebBrowserEngine] = None
|
||||
|
|
|
|||
|
|
@ -42,28 +42,6 @@ class AttrDict(BaseModel):
|
|||
raise AttributeError(f"No such attribute: {key}")
|
||||
|
||||
|
||||
class LLMInstance:
|
||||
"""Mixin class for LLM"""
|
||||
|
||||
# _config: Optional[Config] = None
|
||||
_llm_config: Optional[LLMConfig] = None
|
||||
_llm_instance: Optional[BaseLLM] = None
|
||||
|
||||
def __init__(self, config: Config, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI):
|
||||
"""Use a LLM provider"""
|
||||
# 更新LLM配置
|
||||
self._llm_config = config.get_llm_config(name, provider)
|
||||
# 重置LLM实例
|
||||
self._llm_instance = None
|
||||
|
||||
@property
|
||||
def instance(self) -> BaseLLM:
|
||||
"""Return the LLM instance"""
|
||||
if not self._llm_instance and self._llm_config:
|
||||
self._llm_instance = create_llm_instance(self._llm_config)
|
||||
return self._llm_instance
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
"""Env context for MetaGPT"""
|
||||
|
||||
|
|
@ -74,7 +52,8 @@ class Context(BaseModel):
|
|||
git_repo: Optional[GitRepository] = None
|
||||
src_workspace: Optional[Path] = None
|
||||
cost_manager: CostManager = CostManager()
|
||||
_llm: Optional[LLMInstance] = None
|
||||
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
||||
@property
|
||||
def file_repo(self):
|
||||
|
|
@ -92,12 +71,19 @@ class Context(BaseModel):
|
|||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Use a LLM instance"""
|
||||
# self._llm_config = self.config.get_llm_config(name, provider)
|
||||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
"""Return a LLM instance"""
|
||||
llm = LLMInstance(self.config, name, provider).instance
|
||||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self.cost_manager
|
||||
return llm
|
||||
"""Return a LLM instance, fixme: support multiple llm instances"""
|
||||
if self._llm is None:
|
||||
self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
|
||||
if self._llm.cost_manager is None:
|
||||
self._llm.cost_manager = self.cost_manager
|
||||
return self._llm
|
||||
|
||||
|
||||
class ContextMixin(BaseModel):
|
||||
|
|
@ -108,11 +94,22 @@ class ContextMixin(BaseModel):
|
|||
# Env/Role/Action will use this config as private config, or use self.context.config as public config
|
||||
_config: Optional[Config] = None
|
||||
|
||||
def __init__(self, context: Optional[Context] = None, config: Optional[Config] = None, **kwargs):
|
||||
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
|
||||
_llm_config: Optional[LLMConfig] = None
|
||||
_llm: Optional[BaseLLM] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Optional[Context] = None,
|
||||
config: Optional[Config] = None,
|
||||
llm: Optional[BaseLLM] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize with config"""
|
||||
super().__init__(**kwargs)
|
||||
self.set_context(context)
|
||||
self.set_config(config)
|
||||
self.set_llm(llm)
|
||||
|
||||
def set(self, k, v, override=False):
|
||||
"""Set attribute"""
|
||||
|
|
@ -127,30 +124,56 @@ class ContextMixin(BaseModel):
|
|||
"""Set config"""
|
||||
self.set("_config", config, override)
|
||||
|
||||
def set_llm_config(self, llm_config: LLMConfig, override=False):
|
||||
"""Set llm config"""
|
||||
self.set("_llm_config", llm_config, override)
|
||||
|
||||
def set_llm(self, llm: BaseLLM, override=False):
|
||||
"""Set llm"""
|
||||
self.set("_llm", llm, override)
|
||||
|
||||
def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
"""Use a LLM instance"""
|
||||
self._llm_config = self.config.get_llm_config(name, provider)
|
||||
self._llm = None
|
||||
return self.llm
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
def config(self) -> Config:
|
||||
"""Role config: role config > context config"""
|
||||
if self._config:
|
||||
return self._config
|
||||
return self.context.config
|
||||
|
||||
@config.setter
|
||||
def config(self, config: Config):
|
||||
def config(self, config: Config) -> None:
|
||||
"""Set config"""
|
||||
self.set_config(config)
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
def context(self) -> Context:
|
||||
"""Role context: role context > context"""
|
||||
if self._context:
|
||||
return self._context
|
||||
return CONTEXT
|
||||
|
||||
@context.setter
|
||||
def context(self, context: Context):
|
||||
def context(self, context: Context) -> None:
|
||||
"""Set context"""
|
||||
self.set_context(context)
|
||||
|
||||
@property
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Role llm: role llm > context llm"""
|
||||
if self._llm_config and not self._llm:
|
||||
self._llm = self.context.llm(self._llm_config.name, self._llm_config.provider)
|
||||
return self._llm or self.context.llm()
|
||||
|
||||
@llm.setter
|
||||
def llm(self, llm: BaseLLM) -> None:
|
||||
"""Set llm"""
|
||||
self._llm = llm
|
||||
|
||||
|
||||
# Global context, not in Env
|
||||
CONTEXT = Context()
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class Engineer(Role):
|
|||
coding_context = await todo.run()
|
||||
# Code review
|
||||
if review:
|
||||
action = WriteCodeReview(context=coding_context, g_context=self.context, llm=self.llm)
|
||||
action = WriteCodeReview(context=coding_context, _context=self.context, llm=self.llm)
|
||||
self._init_action_system_message(action)
|
||||
coding_context = await action.run()
|
||||
await src_file_repo.save(
|
||||
|
|
|
|||
|
|
@ -31,11 +31,9 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.context import ContextMixin
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider import HumanProvider
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
|
@ -131,7 +129,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
|
||||
role_id: str = ""
|
||||
states: list[str] = []
|
||||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
|
|||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.actions.execute_task import ExecuteTask
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.make_sk_kernel import make_sk_kernel
|
||||
|
|
@ -44,7 +42,6 @@ class SkAgent(Role):
|
|||
plan: Plan = Field(default=None, exclude=True)
|
||||
planner_cls: Any = None
|
||||
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
kernel: Kernel = Field(default_factory=Kernel)
|
||||
import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
|
||||
import_skill: Callable = Field(default=None, exclude=True)
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@
|
|||
"""
|
||||
from typing import Union
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class Moderation:
|
||||
def __init__(self):
|
||||
self.llm = LLM()
|
||||
def __init__(self, llm: BaseLLM):
|
||||
self.llm = llm
|
||||
|
||||
def handle_moderation_results(self, results):
|
||||
resp = []
|
||||
|
|
|
|||
|
|
@ -16,9 +16,6 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
|
||||
class OpenAIText2Image:
|
||||
def __init__(self, llm: BaseLLM):
|
||||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self.llm = llm
|
||||
|
||||
async def text_2_image(self, text, size_type="1024x1024"):
|
||||
|
|
|
|||
|
|
@ -79,3 +79,6 @@ def test_config_mixin_3():
|
|||
assert obj.b == "b"
|
||||
assert obj.c == "c"
|
||||
assert obj.d == "d"
|
||||
|
||||
print(obj.__dict__.keys())
|
||||
assert "_config" in obj.__dict__.keys()
|
||||
|
|
|
|||
|
|
@ -66,7 +66,5 @@ def test_context_2():
|
|||
def test_context_3():
|
||||
ctx = Context()
|
||||
ctx.use_llm(provider=LLMType.OPENAI)
|
||||
assert ctx.llm_config is not None
|
||||
assert ctx.llm_config.api_type == LLMType.OPENAI
|
||||
assert ctx.llm is not None
|
||||
assert "gpt" in ctx.llm.model
|
||||
assert ctx.llm() is not None
|
||||
assert "gpt" in ctx.llm().model
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.context import CONTEXT
|
||||
from metagpt.tools.moderation import Moderation
|
||||
|
||||
|
||||
|
|
@ -27,7 +28,7 @@ async def test_amoderation(content):
|
|||
assert not CONFIG.OPENAI_API_TYPE
|
||||
assert CONFIG.OPENAI_API_MODEL
|
||||
|
||||
moderation = Moderation()
|
||||
moderation = Moderation(CONTEXT.llm())
|
||||
results = await moderation.amoderation(content=content)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(content)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue