-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLLMHelper.py
More file actions
267 lines (225 loc) · 10.2 KB
/
LLMHelper.py
File metadata and controls
267 lines (225 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import sys
import os
from openai import OpenAI
import openai
import os
import time
import json
from datetime import datetime
import re
import os
import shutil
from pathlib import Path
import json
# 请确保您已将 API Key 存储在环境变量 ARK_API_KEY 中
# 初始化Openai客户端,从环境变量中读取您的API Key
client = OpenAI(
# 此为默认路径,您可根据业务所在地域进行配置
base_url="https://ark.cn-beijing.volces.com/api/v3",
# 从环境变量中获取您的 API Key
api_key=os.environ.get("ARK_API_KEY"),
)
def log_error(message):
"""记录错误信息到日志文件"""
with open(error_log, "a", encoding="utf-8") as f:
f.write(f"[{datetime.now()}] {message}\n")
def llm_call_model(system_prompt, role_prompt, max_retries=100):
# 配置文件路径
OUTPUT_DIR = "doubao_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 生成唯一文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(OUTPUT_DIR, f"doubao_response_{timestamp}.txt")
error_log = os.path.join(OUTPUT_DIR, f"error_log_{timestamp}.txt")
"""调用豆包模型并处理流式响应"""
for attempt in range(max_retries):
try:
# 创建并打开输出文件
with open(output_file, "w", encoding="utf-8") as f:
f.write(f"===== 豆包API调用记录 - {timestamp} =====\n")
f.write(f"提示词: {role_prompt}\n\n")
f.write("===== 响应内容 =====\n")
# 调用豆包API并获取流式响应
response = client.chat.completions.create(
model="doubao-1-5-thinking-pro-250415",
messages=[{"role": "system", "content": system_prompt},
{"role": "user", "content": role_prompt}],
stream=True
)
full_response = ""
for chunk in response:
# 处理流式响应
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
print(content, end="", flush=True)
f.write(content)
full_response += content
# 检查是否有完成原因
if chunk.choices and chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason
print(f"\n\n===== 响应结束 - 原因: {finish_reason} =====")
f.write(f"\n\n===== 响应结束 - 原因: {finish_reason} =====")
# 记录完整响应信息
metadata = {
"timestamp": timestamp,
"model": "doubao-1-5-thinking-pro-250415",
# "prompt": role_prompt,
"finish_reason": finish_reason if 'finish_reason' in locals() else "unknown",
"response_length": len(full_response),
"output_file": output_file,
"error_log": error_log if os.path.exists(error_log) else None
}
metadata_file = os.path.join(OUTPUT_DIR, f"metadata_{timestamp}.json")
with open(metadata_file, "w", encoding="utf-8") as mf:
json.dump(metadata, mf, ensure_ascii=False, indent=2)
print(f"\n完整输出已保存到: {output_file}")
return full_response
except openai.AuthenticationError as e:
error_msg = f"认证失败: {str(e)}"
log_error(error_msg)
print(f"\n错误: {error_msg}")
# 认证错误不重试
break
except openai.RateLimitError as e:
error_msg = f"请求频率超限: {str(e)},将在{attempt+1}秒后重试"
log_error(error_msg)
print(f"\n错误: {error_msg}")
time.sleep(attempt + 1) # 指数退避重试
except openai.ServiceUnavailableError as e:
error_msg = f"服务不可用: {str(e)},将在{attempt+1}秒后重试"
log_error(error_msg)
print(f"\n错误: {error_msg}")
time.sleep(attempt + 1) # 指数退避重试
except openai.OpenAIError as e:
error_msg = f"OpenAI API错误: {str(e)},将在{attempt+1}秒后重试"
log_error(error_msg)
print(f"\n错误: {error_msg}")
time.sleep(attempt + 1) # 指数退避重试
except Exception as e:
error_msg = f"未知错误: {str(e)}"
log_error(error_msg)
print(f"\n错误: {error_msg}")
# 未知错误不重试
break
# 如果所有重试都失败
if not os.path.exists(output_file) or os.path.getsize(output_file) == 0:
# 创建一个空的错误报告文件
with open(output_file, "w", encoding="utf-8") as f:
f.write(f"===== 豆包API调用失败 - {timestamp} =====\n")
f.write(f"提示词: {prompt}\n\n")
f.write(f"错误: 所有重试均失败,详情请查看错误日志: {error_log}")
print(f"\n警告: 所有重试均失败,已创建空的报告文件: {output_file}")
print(f"错误详情请查看日志文件: {error_log}")
return None
def clean_text_start(text, encoding='utf-8'):
"""
清理文本开头的乱码字符
参数:
text (str): 需要清理的文本
encoding (str): 预期的文本编码,默认为'utf-8'
返回:
str: 清理后的文本
"""
# 常见乱码字符模式(可根据实际情况扩展)
garbage_pattern = r'^[^\w\s\u4e00-\u9fff]+'
# 尝试检测并修复BOM头
if text.startswith('\ufeff'):
text = text[1:]
# 移除开头的乱码字符
cleaned_text = re.sub(garbage_pattern, '', text)
# 检测是否有行号标记并移除(如"01 "、"??01"等)
line_number_pattern = r'^(\?+|\d+\s*[-\—]\s*)*'
cleaned_text = re.sub(line_number_pattern, '', cleaned_text)
# 移除多余的空白行
cleaned_text = cleaned_text.lstrip('\n\r')
return cleaned_text
def has_chinese_char(s):
# 使用正则表达式
pattern = re.compile(r'[\u4e00-\u9fff]')
return bool(pattern.search(s))
def estimate_tokens(text):
"""
基于文本类型估算token数量
英文: 平均4个字符=1个token
中文: 平均1个汉字=1个token
符号: 平均2个符号=1个token
"""
# 中文字符范围
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
# 英文字符范围
english_pattern = re.compile(r'[a-zA-Z]')
# 符号范围 (排除中文和英文)
symbol_pattern = re.compile(r'[^\u4e00-\u9fff^a-zA-Z]')
chinese_count = len(chinese_pattern.findall(text))
english_count = len(english_pattern.findall(text))
symbol_count = len(symbol_pattern.findall(text))
# 估算token数量
estimated_tokens = (
chinese_count * 1.0 # 中文: 1字符=1token
+ english_count / 4.0 # 英文: 4字符=1token
+ symbol_count / 2.0 # 符号: 2符号=1token
)
return int(estimated_tokens)
def llm_process_long_text(system_prompt, long_role_prompt, max_tokens=3000, max_retries=100):
"""
处理长文本输入,智能合并多个段落以充分利用token限制
:param system_prompt: 系统提示词
:param long_role_prompt: 长用户提示词
:param max_tokens: 模型输入的最大token数量
:param model_name: 使用的模型名称
:param max_retries: 最大重试次数
:return: 完整的响应文本
"""
# 估算系统提示词的token数量
system_tokens = estimate_tokens(system_prompt)
# 可用于用户提示词的token数量
# available_tokens = max_tokens - system_tokens
available_tokens = max_tokens
# 按段落分割文本(保留空行作为段落分隔符)
paragraphs = re.split(r'(\n\s*\n)', long_role_prompt)
full_response = ""
current_batch = []
current_batch_tokens = 0
for i, para in enumerate(paragraphs):
# 跳过空段落
if not para.strip():
continue
# 估算当前段落的token数量
para_tokens = estimate_tokens(para)
# 如果当前段落单独就超过了最大限制
if para_tokens > available_tokens:
print(f"警告: 段落长度 {para_tokens} tokens 超过可用token数量 {available_tokens},此段落将被单独处理")
# 先处理当前批次
if current_batch:
batch_text = ''.join(current_batch)
response = llm_call_model(system_prompt, batch_text, max_retries)
if response:
full_response += response
current_batch = []
current_batch_tokens = 0
# 单独处理超长段落(可能会失败,但至少不会影响其他段落)
response = llm_call_model(system_prompt, para, max_retries)
if response:
full_response += response
continue
# 如果添加当前段落不会超过限制,则添加到当前批次
if current_batch_tokens + para_tokens <= available_tokens:
current_batch.append(para)
current_batch_tokens += para_tokens
else:
# 当前批次已满,处理并开始新批次
if current_batch:
batch_text = ''.join(current_batch)
response = llm_call_model(system_prompt, batch_text, max_retries)
if response:
full_response += response
# 开始新批次
current_batch = [para]
current_batch_tokens = para_tokens
# 处理最后一个批次
if current_batch:
batch_text = ''.join(current_batch)
response = llm_call_model(system_prompt, batch_text, max_retries)
if response:
full_response += response
return full_response