fix llm set

This commit is contained in:
geekan 2024-01-11 21:48:42 +08:00
parent f59449d5d2
commit 6e44b2b515
8 changed files with 38 additions and 24 deletions

View file

@ -13,7 +13,9 @@ from metagpt.roles import Role
from metagpt.team import Team
# Two speaking actions that share the same instruction text; each one is pinned to its own model below.
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action1.llm.model = "gpt-4-1106-preview"
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
action2.llm.model = "gpt-3.5-turbo-1106"
# Each role is given its own action and is told to watch the other role's action.
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
# Environment description; presumably surfaced to the roles through their prompt prefix -- confirm in Role._get_prefix.
env = Environment(desc="US election live broadcast")

View file

@ -25,7 +25,7 @@ from metagpt.utils.file_repository import FileRepository
class Action(SerializationMixin, ContextMixin, BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = ""
i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""

View file

@ -57,6 +57,8 @@ class ContextMixin(BaseModel):
def set_config(self, config: Config, override=False):
    """Store ``config`` as the private config and initialise the llm.

    Args:
        config: Value stored under ``"private_config"`` via ``self.set``;
            may be ``None``, in which case the llm is not touched.
        override: Forwarded unchanged to ``self.set``.
    """
    self.set("private_config", config, override)
    if config is None:
        return
    # The attribute is read only for its side effect of initialising the llm.
    _ = self.llm  # init llm
def set_llm(self, llm: BaseLLM, override=False):
"""Set llm"""

View file

@ -220,10 +220,12 @@ class OpenAILLM(BaseLLM):
@handle_exception
def _update_costs(self, usage: CompletionUsage):
    """Record the token usage of one completion with the cost manager.

    Nothing is recorded unless ``self.config.calc_usage`` is truthy,
    ``usage`` is truthy and a cost manager is set.

    Args:
        usage: Usage object whose ``prompt_tokens`` and
            ``completion_tokens`` are passed to the cost manager.
    """
    # One guard is enough: the cost-manager check already implies the
    # weaker ``calc_usage and usage`` check that used to wrap it.
    if not (self.config.calc_usage and usage and self.cost_manager):
        return
    self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
    """Return the costs reported by the cost manager.

    Returns:
        ``self.cost_manager.get_costs()`` when a cost manager is set,
        otherwise a default-constructed ``Costs``.
    """
    manager = self.cost_manager
    return manager.get_costs() if manager else Costs()
def _get_max_tokens(self, messages: list[dict]):

View file

@ -131,6 +131,13 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
role_id: str = ""
states: list[str] = []
# scenarios to set action system_prompt:
# 1. `__init__` while using Role(actions=[...])
# 2. add action to role while using `role.set_action(action)`
# 3. set_todo while using `role.set_todo(action)`
# 4. when role.system_prompt is being updated (e.g. by `role.system_prompt = "..."`)
# Additionally, if an action's llm is not set, the role's llm is used
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
addresses: set[str] = set()
@ -222,6 +229,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
def _setting(self):
return f"{self.name}({self.profile})"
@model_validator(mode="after")
def _check_actions(self):
    """Re-apply ``set_actions`` to the role's current actions after validation.

    Returns:
        The instance itself.
    """
    current_actions = self.actions
    self.set_actions(current_actions)
    return self
def _init_action(self, action: Action):
action.set_llm(self.llm, override=False)
action.set_prefix(self._get_prefix())
@ -306,6 +319,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if env:
env.set_addresses(self, self.addresses)
self.llm.system_prompt = self._get_prefix()
self.set_actions(self.actions) # reset actions to update llm and prefix
def _get_prefix(self):
"""Get the role prefix"""
@ -318,7 +332,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
if self.rc.env and self.rc.env.desc:
other_role_names = ", ".join(self.rc.env.role_names())
all_roles = self.rc.env.role_names()
other_role_names = ", ".join([r for r in all_roles if r != self.name])
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
prefix += env_desc
return prefix
@ -478,7 +493,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
if not msg.cause_by:
msg.cause_by = UserRequirement
self.put_message(msg)
if not await self._observe():
# If there is no new information, suspend and wait
logger.debug(f"{self._setting}: no news. waiting.")

File diff suppressed because one or more lines are too long

View file

@ -23,14 +23,12 @@ from metagpt.team import Team
async def test_debate_two_roles():
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
biden = Role(
alex = Role(
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
)
trump = Role(
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
)
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
env = Environment(desc="US election live broadcast")
team = Team(investment=10.0, env=env, roles=[biden, trump])
team = Team(investment=10.0, env=env, roles=[alex, bob])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
assert "Alex" in history
@ -39,9 +37,9 @@ async def test_debate_two_roles():
@pytest.mark.asyncio
async def test_debate_one_role_in_env():
    """A single role run through a team should show up in the returned history."""
    action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
    alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
    env = Environment(desc="US election live broadcast")
    # Build exactly one team around the one role under test, so no stale
    # duplicate role or team is attached to the same environment.
    team = Team(investment=10.0, env=env, roles=[alex])
    history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
    assert "Alex" in history
@ -49,8 +47,8 @@ async def test_debate_one_role_in_env():
@pytest.mark.asyncio
async def test_debate_one_role():
    """A role run on its own should return a non-trivial message attributed to ``Role``."""
    action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
    alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
    # Run the role once; a second identically configured role and a second
    # awaited run would only repeat the call and discard the first result.
    msg: Message = await alex.run("Topic: climate change. Under 80 words per message.")
    assert len(msg.content) > 10
    assert msg.sent_from == "metagpt.roles.role.Role"

View file

@ -104,19 +104,12 @@ async def test_debate_two_roles():
config.llm.model = "gpt-4-1106-preview"
action1 = Action(config=config, name="AlexSay", instruction="Say your opinion with emotion and don't repeat it")
action2 = Action(name="BobSay", instruction="Say your opinion with emotion and don't repeat it")
biden = Role(
alex = Role(
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
)
trump = Role(
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
)
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
env = Environment(desc="US election live broadcast")
team = Team(investment=10.0, env=env, roles=[biden, trump])
print(action1.llm.system_prompt)
print(action2.llm.system_prompt)
print(biden.llm.system_prompt)
print(trump.llm.system_prompt)
team = Team(investment=10.0, env=env, roles=[alex, bob])
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
assert "Alex" in history