Merge pull request #867 from shenchucheng/fix-research-bugs

fix research bugs
This commit is contained in:
geekan 2024-02-07 18:04:28 +08:00 committed by GitHub
commit a54e18f8ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 5 additions and 2 deletions

View file

@ -253,7 +253,9 @@ class OpenAILLM(BaseLLM):
def _get_max_tokens(self, messages: list[dict]):
    """Return the ``max_tokens`` value to request for ``messages``.

    Args:
        messages: Chat messages that will be sent with the request.

    Returns:
        ``self.config.max_token`` unchanged when ``self.auto_max_tokens`` is
        falsy; otherwise the value computed by ``get_max_completion_tokens``
        for ``self.model``, capped at 4096.
    """
    if not self.auto_max_tokens:
        return self.config.max_token
    # FIXME: hard-coded ceiling. Per the thread below, gpt-3.5-turbo-1106
    # limits max_tokens to 4096, so the computed budget must never exceed it.
    # https://community.openai.com/t/why-is-gpt-3-5-turbo-1106-max-tokens-limited-to-4096/494973/3
    return min(get_max_completion_tokens(messages, self.model, self.config.max_token), 4096)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):

View file

@ -93,7 +93,7 @@ def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str
continue
ret = ["".join(j) for j in _split_by_count(sentences, count)]
return ret
return _split_by_count(paragraph, count)
return list(_split_by_count(paragraph, count))
def decode_unicode_escape(text: str) -> str:

View file

@ -42,6 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8),
],
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):