geekan 2024-01-10 18:32:03 +08:00 committed by 莘权 马
parent 0d742654d4
commit ae0a91c025
4 changed files with 17 additions and 11 deletions

View file

@@ -132,7 +132,7 @@ class WriteCode(Action):
         code = await self.write_code(prompt)
         if not coding_context.code_doc:
             # avoid root_path pydantic ValidationError if use WriteCode alone
-            root_path = self.i_context.src_workspace if self.i_context.src_workspace else ""
+            root_path = self.context.src_workspace if self.context.src_workspace else ""
             coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
         coding_context.code_doc.content = code
         return coding_context
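
The fix reads src_workspace from the action's shared runtime context (self.context) rather than from the per-run input context, and coerces a missing workspace to "" so that Document.root_path never receives None. A minimal sketch of the fallback pattern; apart from the Document.root_path requirement taken from the diff, all names here are illustrative:

    from pathlib import Path
    from typing import Optional

    def safe_root_path(src_workspace: Optional[Path]) -> str:
        # A missing workspace becomes "", so validation of a required
        # str field such as Document.root_path cannot fail on None.
        return str(src_workspace) if src_workspace else ""

    assert safe_root_path(None) == ""
    assert safe_root_path(Path("/tmp/ws")) == "/tmp/ws"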

View file

@@ -121,12 +121,10 @@ class Config(CLIParams, YamlModel):
             return llm[0]
         return None
-    def get_llm_config(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> LLMConfig:
+    def get_llm_config(self, name: Optional[str] = None, provider: LLMType = None) -> LLMConfig:
         """Return a LLMConfig instance"""
         if provider:
             llm_configs = self.get_llm_configs_by_type(provider)
             if name:
                 llm_configs = [c for c in llm_configs if c.name == name]
             if len(llm_configs) == 0:
                 raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
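
The key change is the provider default moving from LLMType.OPENAI to None, so a bare get_llm_config() call no longer silently filters to OpenAI-only configs. A hedged usage sketch; the import paths and the Config.default() loader are assumed from this era of the codebase:

    from metagpt.config2 import Config
    from metagpt.configs.llm_config import LLMType

    config = Config.default()                             # assumed default loader
    cfg = config.get_llm_config()                         # provider=None: no type filter
    cfg = config.get_llm_config(provider=LLMType.OPENAI)  # filter by provider type
    # filter by both; per the diff, ValueError is raised when nothing matches
    cfg = config.get_llm_config(name="gpt-4", provider=LLMType.OPENAI)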

View file

@@ -77,7 +77,7 @@ class Context(BaseModel):
     # self._llm = None
     # return self._llm

-    def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
+    def llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
         """Return a LLM instance, fixme: support cache"""
         # if self._llm is None:
         self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
@@ -85,6 +85,14 @@ class Context(BaseModel):
             self._llm.cost_manager = self.cost_manager
         return self._llm

+    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
+        """Return a LLM instance, fixme: support cache"""
+        # if self._llm is None:
+        llm = create_llm_instance(llm_config)
+        if llm.cost_manager is None:
+            llm.cost_manager = self.cost_manager
+        return llm
+

 class ContextMixin(BaseModel):
     """Mixin class for context and config"""
@@ -132,7 +140,7 @@ class ContextMixin(BaseModel):
         """Set llm"""
         self.set("_llm", llm, override)

-    def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
+    def use_llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
         """Use a LLM instance"""
         self._llm_config = self.config.get_llm_config(name, provider)
         self._llm = None
@@ -165,9 +173,9 @@ class ContextMixin(BaseModel):
     @property
     def llm(self) -> BaseLLM:
         """Role llm: role llm > context llm"""
-        # print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}")
+        print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}")
         if self._llm_config and not self._llm:
-            self._llm = self.context.llm(self._llm_config.name, self._llm_config.provider)
+            self._llm = self.context.llm_with_cost_manager_from_llm_config(self._llm_config)
         return self._llm or self.context.llm()

     @llm.setter
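
Two things change in this hunk: the diagnostic print is uncommented, and a role-level _llm_config is now resolved through the new cost-manager helper instead of being re-looked-up by name and provider. The property's resolution order, restated as a plain function for clarity (illustrative only, and it touches private attributes):

    def resolve_llm(mixin) -> BaseLLM:
        # 1. a role-level config wins and is built via the cost-manager helper
        if mixin._llm_config and not mixin._llm:
            mixin._llm = mixin.context.llm_with_cost_manager_from_llm_config(mixin._llm_config)
        # 2. otherwise reuse the cached instance; 3. else the context-wide default
        return mixin._llm or mixin.context.llm()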

View file

@@ -19,8 +19,8 @@ from metagpt.const import (
     TEST_OUTPUTS_FILE_REPO,
 )
 from metagpt.context import CONTEXT
-from metagpt.llm import LLM
 from metagpt.logs import logger
+from metagpt.provider.openai_api import OpenAILLM as LLM
 from metagpt.schema import CodingContext, Document
 from metagpt.utils.common import aread
 from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
@@ -32,7 +32,7 @@ async def test_write_code():
     ccontext = CodingContext(
         filename="task_filename.py", design_doc=Document(content="Design a function named 'add' that takes two integers as input and returns their sum.")
     )
     doc = Document(content=ccontext.model_dump_json())
-    write_code = WriteCode(context=doc)
+    write_code = WriteCode(i_context=doc)
     code = await write_code.run()
     logger.info(code.model_dump_json())
@@ -86,7 +86,7 @@ async def test_write_code_deps():
     )
     coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json())

-    action = WriteCode(context=coding_doc)
+    action = WriteCode(i_context=coding_doc)
     rsp = await action.run()
     assert rsp
     assert rsp.code_doc.content
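
Both tests now pass the input document as i_context: after this refactor, context names the shared runtime Context object, while i_context carries the action's input payload. Illustrative only:

    doc = Document(root_path="snake1", filename="game.py", content="...")
    action = WriteCode(i_context=doc)  # input payload for this action
    # action.context is the shared metagpt Context, injected separately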