From 379ae511ee5ac58f08bb8b79d59e490240fb2692 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Sun, 14 Jan 2024 17:11:21 +0800 Subject: [PATCH] RUN pre-commit --- examples/tot/creative_writing.py | 22 +++++++++++----------- examples/tot/game24.py | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/tot/creative_writing.py b/examples/tot/creative_writing.py index e93d3127d..03a05e6ea 100644 --- a/examples/tot/creative_writing.py +++ b/examples/tot/creative_writing.py @@ -4,6 +4,7 @@ # @Desc : import re +from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt from metagpt.strategy.tot import TreeofThought from metagpt.strategy.tot_schema import ( BaseEvaluator, @@ -11,19 +12,18 @@ from metagpt.strategy.tot_schema import ( Strategy, ThoughtSolverConfig, ) -from examples.tot.prompt_templates.creative_writing import cot_prompt, vote_prompt class TextGenParser(BaseParser): propose_prompt: str = cot_prompt value_prompt: str = vote_prompt - + def __call__(self, input_text: str) -> str: return input_text - + def propose(self, current_state: str, **kwargs) -> str: return self.propose_prompt.format(input=current_state, **kwargs) - + def value(self, input: str = "", **kwargs) -> str: # node_result = self(input) id = kwargs.get("node_id", "0") @@ -33,14 +33,14 @@ class TextGenParser(BaseParser): class TextGenEvaluator(BaseEvaluator): value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc status_map: dict = {val: key for key, val in value_map.items()} - + def __call__(self, evaluation: str, **kwargs) -> float: try: value = 0 node_id = kwargs.get("node_id", "0") pattern = r".*best choice is .*(\d+).*" match = re.match(pattern, evaluation, re.DOTALL) - + if match: vote = int(match.groups()[0]) print(vote) @@ -49,7 +49,7 @@ class TextGenEvaluator(BaseEvaluator): except: value = 0 return value - + def status_verify(self, value): status = False if value in self.status_map: @@ -61,13 +61,13 @@ class TextGenEvaluator(BaseEvaluator): if __name__ == "__main__": import asyncio - + initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are.""" - + parser = TextGenParser() evaluator = TextGenEvaluator() - + config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator) - + tot_base = TreeofThought(strategy=Strategy.BFS, config=config) asyncio.run(tot_base.solve(init_prompt=initial_prompt)) diff --git a/examples/tot/game24.py b/examples/tot/game24.py index 62cc6e690..bbf360def 100644 --- a/examples/tot/game24.py +++ b/examples/tot/game24.py @@ -4,6 +4,7 @@ # @Desc : import re +from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt from metagpt.strategy.tot import TreeofThought from metagpt.strategy.tot_schema import ( BaseEvaluator, @@ -11,29 +12,28 @@ from metagpt.strategy.tot_schema import ( Strategy, ThoughtSolverConfig, ) -from examples.tot.prompt_templates.game24 import propose_prompt, value_prompt class Game24Parser(BaseParser): propose_prompt: str = propose_prompt value_prompt: str = value_prompt - + def __call__(self, input_text: str) -> str: last_line = input_text.strip().split("\n")[-1] return last_line.split("left: ")[-1].split(")")[0] - + def propose(self, current_state: str, **kwargs) -> str: return self.propose_prompt.format(input=current_state, **kwargs) - + def value(self, input: str = "", **kwargs) -> str: node_result = self(input) return self.value_prompt.format(input=node_result) class Game24Evaluator(BaseEvaluator): - value_map : dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc - status_map : dict = {val: key for key, val in value_map.items()} - + value_map: dict = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc + status_map: dict = {val: key for key, val in value_map.items()} + def __call__(self, evaluation: str, **kwargs) -> float: try: matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation) @@ -41,7 +41,7 @@ class Game24Evaluator(BaseEvaluator): except: value = 0.001 return value - + def status_verify(self, value): status = False if value in self.status_map: @@ -53,12 +53,12 @@ class Game24Evaluator(BaseEvaluator): if __name__ == "__main__": import asyncio - + initial_prompt = """4 5 6 10""" parser = Game24Parser() evaluator = Game24Evaluator() - + config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator) - + tot = TreeofThought(strategy=Strategy.BFS, config=config) asyncio.run(tot.solve(init_prompt=initial_prompt))