Merge pull request #14 from iorisa/feature/fixbug_deadloop

fixbug: dead loop
This commit is contained in:
Justin-ZL 2023-09-01 09:53:08 +08:00 committed by GitHub
commit 301dcc27ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -276,6 +276,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = OpenAIGPTAPI.DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
@ -283,18 +285,19 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
data_len = window_size - padding_size
if data_len + idx > total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
w = text[idx:data_len]
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
for i in range(len(windows)):
if i + 1 == len(windows):
break
windows[i] += windows[i + 1][0:padding_size]
return windows
@staticmethod
@ -348,6 +351,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
raise openai.error.OpenAIError("Exceeds the maximum retries")
MAX_TRY = 5
DEFAULT_TOKEN_SIZE = 500
if __name__ == "__main__":