support spark

This commit is contained in:
geekan 2024-01-11 20:51:27 +08:00
parent b9b268ad8b
commit f59449d5d2
5 changed files with 50 additions and 8 deletions

View file

@ -5,13 +5,15 @@
@Author : alexanderwu
@File : llm.py
"""
from typing import Optional
from metagpt.configs.llm_config import LLMConfig
from metagpt.context import CONTEXT
from metagpt.provider.base_llm import BaseLLM
def LLM() -> BaseLLM:
def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM:
"""get the default llm provider if name is None"""
# context.use_llm(name=name, provider=provider)
if llm_config is not None:
CONTEXT.llm_with_cost_manager_from_llm_config(llm_config)
return CONTEXT.llm()

View file

@ -110,7 +110,7 @@ class Engineer(Role):
# Code review
if review:
action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm)
self._init_action_system_message(action)
self._init_action(action)
coding_context = await action.run()
await src_file_repo.save(
coding_context.filename,

View file

@ -146,7 +146,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
super().__init__(**data)
if self.is_human:
self.llm = HumanProvider()
self.llm = HumanProvider(None)
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
@ -222,7 +222,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
def _setting(self):
return f"{self.name}({self.profile})"
def _init_action_system_message(self, action: Action):
def _init_action(self, action: Action):
action.set_llm(self.llm, override=False)
action.set_prefix(self._get_prefix())
def set_action(self, action: Action):
@ -238,7 +239,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self._reset()
for action in actions:
if not isinstance(action, Action):
i = action(name="", llm=self.llm)
i = action()
else:
if self.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(
@ -247,7 +248,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
f"try passing in Action classes instead of initialized instances"
)
i = action
self._init_action_system_message(i)
self._init_action(i)
self.actions.append(i)
self.states.append(f"{len(self.actions)}. {action}")