refactor action_output and action_node

This commit is contained in:
geekan 2023-12-19 21:24:08 +08:00 committed by better629
parent a75ab7971f
commit d159bfc4e1
7 changed files with 18 additions and 39 deletions

View file

@ -6,7 +6,7 @@
@File : action_node.py
"""
import json
from typing import Dict, Generic, List, Optional, Type, TypeVar
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from pydantic import BaseModel, create_model, root_validator, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
@ -127,7 +127,7 @@ class ActionNode(Generic[T]):
return self.get_self_mapping()
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
new_class = create_model(class_name, **mapping)

View file

@ -6,9 +6,7 @@
@File : action_output
"""
from typing import Dict, Type
from pydantic import BaseModel, create_model, root_validator, validator
from pydantic import BaseModel
class ActionOutput:
@ -18,25 +16,3 @@ class ActionOutput:
def __init__(self, content: str, instruct_content: BaseModel):
self.content = content
self.instruct_content = instruct_content
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
new_class = create_model(class_name, **mapping)
@validator("*", allow_reuse=True)
def check_name(v, field):
if field.name not in mapping.keys():
raise ValueError(f"Unrecognized block: {field.name}")
return v
@root_validator(pre=True, allow_reuse=True)
def check_missing_fields(values):
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
return values
new_class.__validator_check_name = classmethod(check_name)
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
return new_class

View file

@ -69,7 +69,7 @@ class WritePRD(Action):
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message:
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)