# LandGPT / true_pre_2level.py
# Origin: Hugging Face repo page (commit 2de6616, "Update true_pre_2level.py" by zhou777)
import csv
import json
import re
from itertools import islice

from tqdm import tqdm

from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image
# ----- Global configuration -----
NUM_EXAMPLES = 900  # Number of lines (i.e., examples) of the JSONL file to process.
QUESTIONS_PER_EXAMPLE = 2  # Standard number of questions per example.
# Trans-level discrimination framework switch:
#   0 = disabled: the second-level question is built from the model's PREDICTED first level;
#   1 = enabled:  the second-level question is built from the GROUND-TRUTH first level.
Trans_level = 1
model = '/path/model'  # Path to the model weights loaded by the lmdeploy pipeline.
jsonl_file = '/path/val.jsonl'  # Validation set: one JSON record per line.
# tp is the number of GPUs used for parallel inference.
# If using a single GPU for inference, set tp=1.
pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192, tp=4))
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)
# Mapping from each first-level land-use category to its candidate second-level categories.
first_level_to_second_levels = {
"Residential Districts": ["High-rise Residential Buildings", "Urban Villages", "Rural Architecture and Farmland"],
"Commercial Zones": ["Business Tower", "Commercial Entertainment", "Office Campus", "Commercial Market", "Shopping Center and Commercial Street"],
"Industrial Land": ["Industrial Park and Factory"],
"Public Services": ["Party and Government Institutions", "Non-profit Public Institutions (Museum; Stadium; Hospital)", "Educational and Research Institutions", "Parks and Squares"],
"Agriculture and Nature": ["Mountain", "Forestland and Grassland", "Water", "Farmland", "Wasteland"],
}
def init_accuracy_dict(levels):
    """Build a fresh tally table: each level name maps to zeroed correct/total counters."""
    counters = {}
    for name in levels:
        counters[name] = {'correct': 0, 'total': 0}
    return counters
# Global per-category accuracy tallies: one flat table for first-level classes,
# and one nested table keyed first-level -> second-level.
first_level_accuracy = init_accuracy_dict(first_level_to_second_levels)
second_level_accuracy = {
    fl_name: init_accuracy_dict(sl_names)
    for fl_name, sl_names in first_level_to_second_levels.items()
}
def extract_level(text, pattern):
    """Match *pattern* at the start of *text*; return capture group 1, or None if no match."""
    m = re.match(pattern, text)
    return m.group(1) if m else None
def extract_poi_info(conv_text):
    """Extract the POI substring sitting between "image'." and the density sentence.

    The POI payload in the prompt is the text after the literal marker "image'."
    and before ".This image includes the 24-hour pedestrian density".
    Returns the stripped payload, or "" when either delimiter is missing.

    Bug fixed: the original added len(marker) to find()'s result BEFORE checking
    it, so a missing start marker (find() == -1) still produced a non-negative
    offset and returned a garbage slice instead of "".
    """
    start_marker = "image'."
    end_marker = ".This image includes the 24-hour pedestrian density"
    marker_pos = conv_text.find(start_marker)
    if marker_pos == -1:  # start marker absent -> no POI block
        return ""
    poi_start = marker_pos + len(start_marker)
    # Search for the end marker only after the start, so a stray earlier
    # occurrence cannot produce an inverted (empty) slice.
    poi_end = conv_text.find(end_marker, poi_start)
    if poi_end == -1:  # end marker absent -> no POI block
        return ""
    return conv_text[poi_start:poi_end].strip()
def extract_pedestrian_density(conv_text):
    """Pull the 24-hour density values out of the "density data is: [...]" span.

    Returns the stripped text between the brackets, or "" when the marker is
    missing or the brackets are empty/unclosed.
    """
    opener = "density data is: ["
    head = conv_text.find(opener)
    if head == -1:
        return ""
    value_start = head + len(opener)
    value_end = conv_text.find("]", value_start)
    # Require at least one character between the brackets, matching the
    # original strict "end > start" check.
    if value_end > value_start:
        return conv_text[value_start:value_end].strip()
    return ""
def process_conversations(data, session=None):
    """Run one example's two-turn conversation through the model and score it.

    Args:
        data: one parsed JSONL record with keys 'id', 'image', and
            'conversations'. Assumes 'conversations' alternates
            human question / ground-truth answer entries (indexed i, i+1)
            — TODO confirm against the dataset format.
        session: optional lmdeploy chat session; updated after each turn and
            passed back into pipe.chat so the two turns share context.

    Returns:
        (correct_first, correct_second, error_log): per-example correct counts
        for the first- and second-level questions, plus a list of mismatch
        records (id, image, question, model answer, expected answer).

    Side effects: mutates the global accuracy dicts and writes result rows via
    the global CSV writers (first_level_writer / second_level_writer), which
    must already be bound by the driver script.
    """
    correct_first = 0
    correct_second = 0
    error_log = []
    image = load_image(data['image'])
    conv = data['conversations']
    predicted_first_level = None
    # Step by 2: even indices are human questions, odd indices the answers.
    for i in range(0, len(conv), 2):
        human_question = conv[i]['value']
        if i == 0:  # First question: predict the FirstLevel category from the image.
            response = pipe.chat((human_question, image), session=session, gen_config=gen_config)
            generated_answer = response.response.text.strip()
            # Ground truth from the dataset answer; prediction from the model
            # (its prefix "The FirstLevel is " is optional in the regex).
            first_level = extract_level(conv[i + 1]['value'], r'The FirstLevel is (.+)$')
            predicted_first_level = extract_level(generated_answer, r'(?:The FirstLevel is )?(.+)$')
            # Update first_level_accuracy only for known categories.
            if first_level in first_level_accuracy:
                first_level_accuracy[first_level]['total'] += 1
                first_level_writer.writerow([data['id'], first_level, predicted_first_level])
                if first_level == predicted_first_level:
                    correct_first += 1
                    first_level_accuracy[first_level]['correct'] += 1
                else:
                    error_log.append({
                        'id': data['id'],
                        'image': data['image'],
                        'question': human_question,
                        'model_answer': generated_answer,
                        'correct_answer': conv[i + 1]['value']
                    })
            else:
                print(f"FirstLevel '{first_level}' not found in predefined levels.")
        elif i == 2:  # Second question: predict the SecondLevel category.
            # Extract POI (Points of Interest) information from the dataset question.
            poi_info = extract_poi_info(conv[i]['value'])
            # Extract the 24-hour pedestrian density information.
            pedestrian_density = extract_pedestrian_density(conv[i]['value'])
            # Trans_level == 0: condition the second question on the model's own
            # prediction; otherwise use the ground-truth first level from turn 1.
            if Trans_level == 0:
                current_first_level = predicted_first_level
            else:
                current_first_level = first_level
            if current_first_level in first_level_to_second_levels:
                second_levels = first_level_to_second_levels[current_first_level]
                # Build the second-turn prompt, embedding the candidate
                # SecondLevel list, the POI info, and the density data.
                human_question = (f"The FirstLevel category of this image is {current_first_level}. "
                                  f"Please select the most likely SecondLevel among {', '.join(second_levels)}. "
                                  "This image contains some POI (Point of Interest) information, "
                                  "which is now provided to you. You can refer to this POI information "
                                  "to make a judgment. "
                                  "If you believe that this image does not belong to any of the above-listed SecondLevel categories, "
                                  "you can also select the SecondLevel category you think this image should belong to based on the POI information and the pedestrian density data below."
                                  "The format of the POI information is: "
                                  f"'POI category'-'the number of occurrences of this category in the image'. {poi_info} "
                                  "Additionally, this image includes the pedestrian density data, "
                                  "which could also help in making a judgment. The density data is: "
                                  f"{pedestrian_density}")
                response = pipe.chat(human_question, session=session, gen_config=gen_config)
                generated_answer = response.response.text.strip()
                second_level = extract_level(conv[i + 1]['value'], r'The SecondLevel is (.+)$')
                predicted_second_level = extract_level(generated_answer, r'(?:The SecondLevel is )?(.+)$')
                # Re-extract the REAL first level from the turn-1 answer so the
                # stats are keyed by ground truth even when Trans_level == 0.
                real_first_level = extract_level(conv[i - 1]['value'], r'The FirstLevel is (.+)$')
                if real_first_level in second_level_accuracy:
                    # Use the real first-level category to find the
                    # corresponding second-level counter table.
                    second_level_data = second_level_accuracy[real_first_level]
                    # Only count second levels that exist under that table.
                    if second_level in second_level_data:
                        second_level_data[second_level]['total'] += 1
                        second_level_writer.writerow([data['id'], real_first_level, second_level, predicted_second_level])
                        # Compare prediction against ground truth and record.
                        if second_level == predicted_second_level:
                            correct_second += 1
                            second_level_data[second_level]['correct'] += 1
                        else:
                            error_log.append({
                                'id': data['id'],
                                'image': data['image'],
                                'question': human_question,
                                'model_answer': generated_answer,
                                'correct_answer': conv[i + 1]['value']
                            })
                    else:
                        print(f"SecondLevel '{second_level}' under Real FirstLevel '{real_first_level}' not found in predefined levels.")
                else:
                    print(f"Real FirstLevel '{real_first_level}' does not exist in the accuracy dictionary.")
        # Carry the latest response forward as the chat session for the next turn.
        # NOTE(review): if the second-level prompt is skipped (unknown first
        # level), this reuses the turn-1 response — presumably intended; confirm.
        session = response
    return correct_first, correct_second, error_log
# ----- Driver: evaluate the validation set and report accuracies -----
correct_first_total = 0
correct_second_total = 0
error_logs = []

# Read at most NUM_EXAMPLES lines. islice stops cleanly on a short file,
# whereas the original "[next(f) for _ in range(NUM_EXAMPLES)]" raised when
# the file held fewer lines than configured.
with open(jsonl_file, 'r', encoding='utf-8') as f:
    lines = list(islice(f, NUM_EXAMPLES))
num_examples = len(lines)  # Examples actually loaded (may be < NUM_EXAMPLES).

with open('first_level_results_true_0_9K.csv', 'w', newline='', encoding='utf-8') as first_level_csv_file, \
        open('second_level_results_true_0_9K.csv', 'w', newline='', encoding='utf-8') as second_level_csv_file:
    # These writers are module-level globals used inside process_conversations.
    first_level_writer = csv.writer(first_level_csv_file)
    second_level_writer = csv.writer(second_level_csv_file)
    first_level_writer.writerow(['id', 'expected', 'predicted'])
    second_level_writer.writerow(['id', 'first_level', 'expected', 'predicted'])
    for line in tqdm(lines, desc="Processing files"):
        data = json.loads(line)
        correct_first, correct_second, errors = process_conversations(data)
        correct_first_total += correct_first
        correct_second_total += correct_second
        error_logs.extend(errors)

# Calculate and print overall accuracy. Divide by the number of examples
# actually processed (the original always divided by NUM_EXAMPLES, skewing
# the numbers on a short file), and guard against an empty input file.
if num_examples > 0:
    total_questions = num_examples * QUESTIONS_PER_EXAMPLE
    first_accuracy = correct_first_total / num_examples
    second_accuracy = correct_second_total / num_examples
    print(f'First question accuracy: {first_accuracy * 100:.2f}%')
    print(f'Second question accuracy: {second_accuracy * 100:.2f}%')
    print(f'Overall accuracy: {((correct_first_total + correct_second_total) / total_questions) * 100:.2f}%')
else:
    print('No examples were read from the input file.')

# Per-category FirstLevel accuracy (skip categories never seen).
for first_level in first_level_accuracy:
    correct = first_level_accuracy[first_level]['correct']
    total = first_level_accuracy[first_level]['total']
    if total > 0:
        accuracy = correct / total
        print(f'Accuracy for FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total} ')

# Per-category SecondLevel accuracy (skip categories never seen).
for first_level, second_levels in second_level_accuracy.items():
    for second_level in second_levels:
        correct = second_levels[second_level]['correct']
        total = second_levels[second_level]['total']
        if total > 0:
            accuracy = correct / total
            print(f'Accuracy for SecondLevel "{second_level}" under FirstLevel "{first_level}": {accuracy * 100:.2f}% right/total: {correct} / {total}')

# Persist the mismatch records for later inspection.
with open('error_log_0_9K', 'w', encoding='utf-8') as outfile:
    json.dump(error_logs, outfile, indent=4)