RUN pre-commit

This commit is contained in:
stellahsr 2024-01-14 17:11:21 +08:00
parent 2829964326
commit 379ae511ee
2 changed files with 22 additions and 22 deletions

View file

@ -4,6 +4,7 @@
# @Desc :
import re
from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
@ -11,19 +12,18 @@ from metagpt.strategy.tot_schema import (
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt
class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt
def __call__(self, input_text: str) -> str:
return input_text
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
@ -33,14 +33,14 @@ class TextGenParser(BaseParser):
class TextGenEvaluator(BaseEvaluator):
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map: dict = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)
if match:
vote = int(match.groups()[0])
print(vote)
@ -49,7 +49,7 @@ class TextGenEvaluator(BaseEvaluator):
except:
value = 0
return value
def status_verify(self, value):
status = False
if value in self.status_map:
@ -61,13 +61,13 @@ class TextGenEvaluator(BaseEvaluator):
if __name__ == "__main__":
import asyncio
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didnt like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
parser = TextGenParser()
evaluator = TextGenEvaluator()
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))

View file

@ -4,6 +4,7 @@
# @Desc :
import re
from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
@ -11,29 +12,28 @@ from metagpt.strategy.tot_schema import (
Strategy,
ThoughtSolverConfig,
)
from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt
class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt
def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)
class Game24Evaluator(BaseEvaluator):
value_map : dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map : dict = {val: key for key, val in value_map.items()}
value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map: dict = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
@ -41,7 +41,7 @@ class Game24Evaluator(BaseEvaluator):
except:
value = 0.001
return value
def status_verify(self, value):
status = False
if value in self.status_map:
@ -53,12 +53,12 @@ class Game24Evaluator(BaseEvaluator):
if __name__ == "__main__":
import asyncio
initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))