refine code

This commit is contained in:
geekan 2024-01-10 15:34:49 +08:00
parent f5bb850f25
commit 5cd5eebc5b
12 changed files with 67 additions and 56 deletions

View file

@ -14,8 +14,6 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.actions.action_node import ActionNode
from metagpt.context import ContextMixin
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import (
CodeSummarizeContext,
CodingContext,
@ -30,7 +28,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager

View file

@ -133,7 +133,6 @@ class GenerateTable(Action):
name: str = "GenerateTable"
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:

View file

@ -178,7 +178,6 @@ class WebBrowseAndSummarize(Action):
name: str = "WebBrowseAndSummarize"
i_context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: Optional[WebBrowserEngine] = None

View file

@ -42,28 +42,6 @@ class AttrDict(BaseModel):
raise AttributeError(f"No such attribute: {key}")
class LLMInstance:
"""Mixin class for LLM"""
# _config: Optional[Config] = None
_llm_config: Optional[LLMConfig] = None
_llm_instance: Optional[BaseLLM] = None
def __init__(self, config: Config, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI):
"""Use a LLM provider"""
# 更新LLM配置
self._llm_config = config.get_llm_config(name, provider)
# 重置LLM实例
self._llm_instance = None
@property
def instance(self) -> BaseLLM:
"""Return the LLM instance"""
if not self._llm_instance and self._llm_config:
self._llm_instance = create_llm_instance(self._llm_config)
return self._llm_instance
class Context(BaseModel):
"""Env context for MetaGPT"""
@ -74,7 +52,8 @@ class Context(BaseModel):
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
_llm: Optional[LLMInstance] = None
_llm: Optional[BaseLLM] = None
@property
def file_repo(self):
@ -92,12 +71,19 @@ class Context(BaseModel):
env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
# """Use a LLM instance"""
# self._llm_config = self.config.get_llm_config(name, provider)
# self._llm = None
# return self._llm
def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
"""Return a LLM instance"""
llm = LLMInstance(self.config, name, provider).instance
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm
"""Return a LLM instance, fixme: support multiple llm instances"""
if self._llm is None:
self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
if self._llm.cost_manager is None:
self._llm.cost_manager = self.cost_manager
return self._llm
class ContextMixin(BaseModel):
@ -108,11 +94,22 @@ class ContextMixin(BaseModel):
# Env/Role/Action will use this config as private config, or use self.context.config as public config
_config: Optional[Config] = None
def __init__(self, context: Optional[Context] = None, config: Optional[Config] = None, **kwargs):
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
_llm_config: Optional[LLMConfig] = None
_llm: Optional[BaseLLM] = None
def __init__(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
):
"""Initialize with config"""
super().__init__(**kwargs)
self.set_context(context)
self.set_config(config)
self.set_llm(llm)
def set(self, k, v, override=False):
"""Set attribute"""
@ -127,30 +124,56 @@ class ContextMixin(BaseModel):
"""Set config"""
self.set("_config", config, override)
def set_llm_config(self, llm_config: LLMConfig, override=False):
"""Set llm config"""
self.set("_llm_config", llm_config, override)
def set_llm(self, llm: BaseLLM, override=False):
"""Set llm"""
self.set("_llm", llm, override)
def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
"""Use a LLM instance"""
self._llm_config = self.config.get_llm_config(name, provider)
self._llm = None
return self.llm
@property
def config(self):
def config(self) -> Config:
"""Role config: role config > context config"""
if self._config:
return self._config
return self.context.config
@config.setter
def config(self, config: Config):
def config(self, config: Config) -> None:
"""Set config"""
self.set_config(config)
@property
def context(self):
def context(self) -> Context:
"""Role context: role context > context"""
if self._context:
return self._context
return CONTEXT
@context.setter
def context(self, context: Context):
def context(self, context: Context) -> None:
"""Set context"""
self.set_context(context)
@property
def llm(self) -> BaseLLM:
"""Role llm: role llm > context llm"""
if self._llm_config and not self._llm:
self._llm = self.context.llm(self._llm_config.name, self._llm_config.provider)
return self._llm or self.context.llm()
@llm.setter
def llm(self, llm: BaseLLM) -> None:
"""Set llm"""
self._llm = llm
# Global context, not in Env
CONTEXT = Context()

View file

@ -109,7 +109,7 @@ class Engineer(Role):
coding_context = await todo.run()
# Code review
if review:
action = WriteCodeReview(context=coding_context, g_context=self.context, llm=self.llm)
action = WriteCodeReview(context=coding_context, _context=self.context, llm=self.llm)
self._init_action_system_message(action)
coding_context = await action.run()
await src_file_repo.save(

View file

@ -31,11 +31,9 @@ from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.context import ContextMixin
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider import HumanProvider
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
@ -131,7 +129,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
desc: str = ""
is_human: bool = False
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
role_id: str = ""
states: list[str] = []
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)

View file

@ -17,9 +17,7 @@ from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel
@ -44,7 +42,6 @@ class SkAgent(Role):
plan: Plan = Field(default=None, exclude=True)
planner_cls: Any = None
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
llm: BaseLLM = Field(default_factory=LLM)
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
import_skill: Callable = Field(default=None, exclude=True)

View file

@ -7,12 +7,12 @@
"""
from typing import Union
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
class Moderation:
def __init__(self):
self.llm = LLM()
def __init__(self, llm: BaseLLM):
self.llm = llm
def handle_moderation_results(self, results):
resp = []

View file

@ -16,9 +16,6 @@ from metagpt.provider.base_llm import BaseLLM
class OpenAIText2Image:
def __init__(self, llm: BaseLLM):
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self.llm = llm
async def text_2_image(self, text, size_type="1024x1024"):

View file

@ -79,3 +79,6 @@ def test_config_mixin_3():
assert obj.b == "b"
assert obj.c == "c"
assert obj.d == "d"
print(obj.__dict__.keys())
assert "_config" in obj.__dict__.keys()

View file

@ -66,7 +66,5 @@ def test_context_2():
def test_context_3():
ctx = Context()
ctx.use_llm(provider=LLMType.OPENAI)
assert ctx.llm_config is not None
assert ctx.llm_config.api_type == LLMType.OPENAI
assert ctx.llm is not None
assert "gpt" in ctx.llm.model
assert ctx.llm() is not None
assert "gpt" in ctx.llm().model

View file

@ -9,6 +9,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.context import CONTEXT
from metagpt.tools.moderation import Moderation
@ -27,7 +28,7 @@ async def test_amoderation(content):
assert not CONFIG.OPENAI_API_TYPE
assert CONFIG.OPENAI_API_MODEL
moderation = Moderation()
moderation = Moderation(CONTEXT.llm())
results = await moderation.amoderation(content=content)
assert isinstance(results, list)
assert len(results) == len(content)