Update true_pre_2level.py
Browse files- true_pre_2level.py +21 -15
true_pre_2level.py
CHANGED
|
@@ -9,6 +9,7 @@ import csv
|
|
| 9 |
# Set global variables
|
| 10 |
NUM_EXAMPLES = 900 # 要处理的 JSONL 文件的行数(即例子数量)The number of lines (i.e., examples) in the JSONL file to be processed.
|
| 11 |
QUESTIONS_PER_EXAMPLE = 2 # 每个例子的标准问题数量 The standard number of questions per example
|
|
|
|
| 12 |
|
| 13 |
model = '/path/model'
|
| 14 |
jsonl_file = '/path/val.jsonl'
|
|
@@ -72,18 +73,10 @@ def process_conversations(data, session=None):
|
|
| 72 |
for i in range(0, len(conv), 2):
|
| 73 |
human_question = conv[i]['value']
|
| 74 |
|
| 75 |
-
# if '<image>' in human_question:
|
| 76 |
-
|
| 77 |
-
# else:
|
| 78 |
-
# response = pipe.chat(human_question, session=session, gen_config=gen_config)
|
| 79 |
-
|
| 80 |
-
# generated_answer = response.response.text.strip()
|
| 81 |
-
|
| 82 |
if i == 0: # First question
|
| 83 |
response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
|
| 84 |
generated_answer = response.response.text.strip()
|
| 85 |
first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
|
| 86 |
-
# predicted_first_level = extract_level(generated_answer, r'The FirstLevel is (.+)$')
|
| 87 |
predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
|
| 88 |
# Update first_level_accuracy
|
| 89 |
if first_level in first_level_accuracy:
|
|
@@ -106,16 +99,22 @@ def process_conversations(data, session=None):
|
|
| 106 |
|
| 107 |
elif i == 2: # Second question
|
| 108 |
# 提取POI信息 POI
|
|
|
|
| 109 |
poi_info = extract_poi_info(conv[i]['value'])
|
| 110 |
# 提取行人密度信息 People
|
|
|
|
| 111 |
pedestrian_density = extract_pedestrian_density(conv[i]['value'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
if
|
| 114 |
-
second_levels = first_level_to_second_levels[
|
| 115 |
# 构建下一个问题,包括POI和行人密度信息
|
| 116 |
-
#
|
| 117 |
-
|
| 118 |
-
human_question = (f"The FirstLevel category of this image is {predicted_first_level}. "
|
| 119 |
f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
|
| 120 |
"This image contains some POI (Point of Interest) information, "
|
| 121 |
"which is now provided to you. You can refer to this POI information "
|
|
@@ -133,22 +132,25 @@ def process_conversations(data, session=None):
|
|
| 133 |
|
| 134 |
second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
|
| 135 |
predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
|
| 136 |
-
# predicted_second_level = extract_level(generated_answer, r'The SecondLevel is (.+)$')
|
| 137 |
|
| 138 |
# Update second_level_accuracy
|
| 139 |
# 确保真实的第一级分类存在于结构中
|
|
|
|
| 140 |
real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
|
| 141 |
|
| 142 |
if real_first_level in second_level_accuracy:
|
| 143 |
# 使用真实的一级分类查找对应的二级分类结构
|
|
|
|
| 144 |
second_level_data = second_level_accuracy[real_first_level]
|
| 145 |
|
| 146 |
# 更新统计总数,确保真实的二级分类存在于结构中
|
|
|
|
| 147 |
if second_level in second_level_data:
|
| 148 |
second_level_data[second_level]['total'] += 1
|
| 149 |
second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
|
| 150 |
|
| 151 |
# 比较和记录正确性
|
|
|
|
| 152 |
if second_level == predicted_second_level:
|
| 153 |
correct_second += 1
|
| 154 |
second_level_data[second_level]['correct'] += 1
|
|
@@ -172,9 +174,9 @@ correct_first_total = 0
|
|
| 172 |
correct_second_total = 0
|
| 173 |
error_logs = []
|
| 174 |
|
| 175 |
-
|
| 176 |
with open(jsonl_file, 'r') as f:
|
| 177 |
lines = [next(f) for _ in range(NUM_EXAMPLES)] # 只读取前 NUM_EXAMPLES 行
|
|
|
|
| 178 |
|
| 179 |
with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_csv_file, \
|
| 180 |
open('second_level_results_true_0_9K.csv', 'w', newline='') as second_level_csv_file:
|
|
@@ -194,6 +196,7 @@ with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_c
|
|
| 194 |
error_logs.extend(errors)
|
| 195 |
|
| 196 |
# 计算和打印正确率
|
|
|
|
| 197 |
total_questions = NUM_EXAMPLES * QUESTIONS_PER_EXAMPLE
|
| 198 |
|
| 199 |
first_accuracy = correct_first_total / NUM_EXAMPLES
|
|
@@ -204,6 +207,7 @@ print(f'Second question accuracy: {second_accuracy * 100:.2f}%')
|
|
| 204 |
print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
|
| 205 |
|
| 206 |
# 计算一级分类正确率
|
|
|
|
| 207 |
for first_level in first_level_accuracy:
|
| 208 |
correct = first_level_accuracy[first_level]['correct']
|
| 209 |
total = first_level_accuracy[first_level]['total']
|
|
@@ -212,6 +216,7 @@ for first_level in first_level_accuracy:
|
|
| 212 |
print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')
|
| 213 |
|
| 214 |
# 计算二级分类正确率
|
|
|
|
| 215 |
for first_level, second_levels in second_level_accuracy.items():
|
| 216 |
for second_level in second_levels:
|
| 217 |
correct = second_levels[second_level]['correct']
|
|
@@ -221,5 +226,6 @@ for first_level, second_levels in second_level_accuracy.items():
|
|
| 221 |
print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')
|
| 222 |
|
| 223 |
# 将错误记录写入日志文件
|
|
|
|
| 224 |
with open('error_log_0_9K', 'w') as outfile:
|
| 225 |
json.dump(error_logs, outfile, indent=4)
|
|
|
|
| 9 |
# Set global variables
|
| 10 |
NUM_EXAMPLES = 900 # 要处理的 JSONL 文件的行数(即例子数量)The number of lines (i.e., examples) in the JSONL file to be processed.
|
| 11 |
QUESTIONS_PER_EXAMPLE = 2 # 每个例子的标准问题数量 The standard number of questions per example
|
| 12 |
+
Trans_level = 1 # 定义全局变量,决定是否切换到跨级判别框架,0为不启用跨级判别框架,1为启用 Define a global variable to decide whether to switch to the trans-level discrimination framework. 0 means disabling the trans-level discrimination framework, and 1 means enabling it.
|
| 13 |
|
| 14 |
model = '/path/model'
|
| 15 |
jsonl_file = '/path/val.jsonl'
|
|
|
|
| 73 |
for i in range(0, len(conv), 2):
|
| 74 |
human_question = conv[i]['value']
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
if i == 0: # First question
|
| 77 |
response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
|
| 78 |
generated_answer = response.response.text.strip()
|
| 79 |
first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
|
|
|
|
| 80 |
predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
|
| 81 |
# Update first_level_accuracy
|
| 82 |
if first_level in first_level_accuracy:
|
|
|
|
| 99 |
|
| 100 |
elif i == 2: # Second question
|
| 101 |
# 提取POI信息 POI
|
| 102 |
+
# Extract POI information (Points of Interest)
|
| 103 |
poi_info = extract_poi_info(conv[i]['value'])
|
| 104 |
# 提取行人密度信息 People
|
| 105 |
+
# Extract pedestrian density information
|
| 106 |
pedestrian_density = extract_pedestrian_density(conv[i]['value'])
|
| 107 |
+
|
| 108 |
+
if Trans_level == 0:
|
| 109 |
+
current_first_level = predicted_first_level
|
| 110 |
+
else:
|
| 111 |
+
current_first_level = first_level
|
| 112 |
|
| 113 |
+
if current_first_level in first_level_to_second_levels:
|
| 114 |
+
second_levels = first_level_to_second_levels[current_first_level]
|
| 115 |
# 构建下一个问题,包括POI和行人密度信息
|
| 116 |
+
# Construct the next question, including POI and pedestrian density information
|
| 117 |
+
human_question = (f"The FirstLevel category of this image is {current_first_level}. "
|
|
|
|
| 118 |
f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
|
| 119 |
"This image contains some POI (Point of Interest) information, "
|
| 120 |
"which is now provided to you. You can refer to this POI information "
|
|
|
|
| 132 |
|
| 133 |
second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
|
| 134 |
predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
|
|
|
|
| 135 |
|
| 136 |
# Update second_level_accuracy
|
| 137 |
# 确保真实的第一级分类存在于结构中
|
| 138 |
+
# Ensure the real first-level category exists in the structure
|
| 139 |
real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
|
| 140 |
|
| 141 |
if real_first_level in second_level_accuracy:
|
| 142 |
# 使用真实的一级分类查找对应的二级分类结构
|
| 143 |
+
# Use the real first-level category to find the corresponding second-level structure
|
| 144 |
second_level_data = second_level_accuracy[real_first_level]
|
| 145 |
|
| 146 |
# 更新统计总数,确保真实的二级分类存在于结构中
|
| 147 |
+
# Update statistics, ensuring the real second-level category exists in the structure
|
| 148 |
if second_level in second_level_data:
|
| 149 |
second_level_data[second_level]['total'] += 1
|
| 150 |
second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
|
| 151 |
|
| 152 |
# 比较和记录正确性
|
| 153 |
+
# Compare and record correctness
|
| 154 |
if second_level == predicted_second_level:
|
| 155 |
correct_second += 1
|
| 156 |
second_level_data[second_level]['correct'] += 1
|
|
|
|
| 174 |
correct_second_total = 0
|
| 175 |
error_logs = []
|
| 176 |
|
|
|
|
| 177 |
with open(jsonl_file, 'r') as f:
|
| 178 |
lines = [next(f) for _ in range(NUM_EXAMPLES)] # 只读取前 NUM_EXAMPLES 行
|
| 179 |
+
# Only read the first NUM_EXAMPLES lines
|
| 180 |
|
| 181 |
with open('first_level_results_true_0_9K.csv', 'w', newline='') as first_level_csv_file, \
|
| 182 |
open('second_level_results_true_0_9K.csv', 'w', newline='') as second_level_csv_file:
|
|
|
|
| 196 |
error_logs.extend(errors)
|
| 197 |
|
| 198 |
# 计算和打印正确率
|
| 199 |
+
# Calculate and print accuracy
|
| 200 |
total_questions = NUM_EXAMPLES * QUESTIONS_PER_EXAMPLE
|
| 201 |
|
| 202 |
first_accuracy = correct_first_total / NUM_EXAMPLES
|
|
|
|
| 207 |
print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
|
| 208 |
|
| 209 |
# 计算一级分类正确率
|
| 210 |
+
# Calculate FirstLevel accuracy
|
| 211 |
for first_level in first_level_accuracy:
|
| 212 |
correct = first_level_accuracy[first_level]['correct']
|
| 213 |
total = first_level_accuracy[first_level]['total']
|
|
|
|
| 216 |
print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')
|
| 217 |
|
| 218 |
# 计算二级分类正确率
|
| 219 |
+
# Calculate SecondLevel accuracy
|
| 220 |
for first_level, second_levels in second_level_accuracy.items():
|
| 221 |
for second_level in second_levels:
|
| 222 |
correct = second_levels[second_level]['correct']
|
|
|
|
| 226 |
print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')
|
| 227 |
|
| 228 |
# 将错误记录写入日志文件
|
| 229 |
+
# Write error logs to a file
|
| 230 |
with open('error_log_0_9K', 'w') as outfile:
|
| 231 |
json.dump(error_logs, outfile, indent=4)
|