# Source: GitHub Gist MarchLiu/b1a4e1504a5b1d4ee97a3a92f56166a3 (created November 29, 2025 06:17).
# Plant "face" recognition machine.
# The class PlantFaceRecognitionSystem can be imported by other programs.
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TensorFlow's C++ INFO/WARNING log lines
# Try the tf.keras import path first, then fall back to the standalone keras package.
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model, Model, model_from_json
    from tensorflow.keras.applications import MobileNetV2
    from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
    TENSORFLOW_AVAILABLE = True
except ImportError as e:
    print(f"TensorFlow导入错误: {e}")
    # Fallback import path.
    try:
        import tensorflow as tf
        from tensorflow import keras
        from keras.models import load_model, Model, model_from_json
        from keras.applications import MobileNetV2
        from keras.layers import GlobalAveragePooling2D, Dense
        from keras.preprocessing import image
        from keras.applications.mobilenet_v2 import preprocess_input
        # The fallback succeeded, so TensorFlow is usable. Without this line the
        # name stayed undefined on this path.
        TENSORFLOW_AVAILABLE = True
    except ImportError:
        print("所有TensorFlow导入都失败")
        # sys.exit does not depend on the `site` module, unlike the exit() helper.
        sys.exit(1)
import json
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity
import warnings
import shutil
import hashlib
import base64
import random
# Silence all Python warnings (e.g. library deprecation notices) for the interactive tool.
warnings.filterwarnings('ignore')
# Random seeding, so that model construction and feature extraction are reproducible.
def set_seed(seed=42):
    """Apply *seed* to Python's, NumPy's and TensorFlow's global random generators.

    Args:
        seed: integer seed shared by all three generators (default 42).
    """
    seeders = (random.seed, np.random.seed, tf.random.set_seed)
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once when the module is loaded; model builders seed again before use.
set_seed(42)
# Working directory. Relative paths used by this program ("db/...", "imgTemp")
# resolve against it. It can be overridden with the PLANT_FACE_WORKDIR
# environment variable. The built-in default is a placeholder. When the
# directory does not exist, stay in the current directory instead of letting
# os.chdir raise FileNotFoundError, which made the module impossible to import.
WORK_DIR = os.environ.get("PLANT_FACE_WORKDIR", ".../working/dir")
if os.path.isdir(WORK_DIR):
    os.chdir(WORK_DIR)
else:
    print(f"警告: 工作目录 {WORK_DIR} 不存在,使用当前目录 {os.getcwd()}")
# This class can be imported by other programs, e.g. for batch testing.
class PlantFaceRecognitionSystem:
    """Enroll plants from photos and recognise them later by feature similarity.

    A Keras feature extractor turns each image into a vector. The vectors are
    L2-normalised and then averaged over the batch of images. Enrollment appends
    the averaged vector to the plant's entry in a JSON database. Recognition
    returns the enrolled plant whose best stored vector has the highest cosine
    similarity with the query. Every recognition is appended to a statistics
    table (CSV), which an administrator can annotate via confirm_recognition().

    Files (relative to the current working directory unless given absolute):
        database_path            -- plant database; features stored as base64 float32
        model_path + .json file  -- extractor saved with Model.save() + its architecture
        db/recognition_stats.csv -- per-recognition statistics
        db/enrolled_hashes.txt   -- MD5 of every image file already enrolled
    """

    # MD5 digests of enrolled image files, one per line.
    HASHES_PATH = "db/enrolled_hashes.txt"
    # Directory that clear_temp_folder() empties by default.
    TEMP_DIR = "imgTemp"
    # Keys every plant entry in the database must have.
    REQUIRED_KEYS = ("features", "enrollment_date", "update_history", "image_count")
    # Columns of the recognition statistics table.
    STATS_COLUMNS = (
        'plant_id', 'recognition_time', 'num_images',
        'recognized_id', 'confidence', 'admin_feedback',
        'true_id', 'correct', 'database_size', 'plant_image_count', 'species_group_count',
    )

    def __init__(self, database_path="db/plant_faces.json", model_path="db/plant_model.h5"):
        """Load or create the extractor, then load the database, statistics and hashes.

        Args:
            database_path: JSON file holding the enrolled plants.
            model_path: file the feature extractor is saved to / loaded from.
                Its architecture is kept in a sibling file with a ``.json``
                extension.
        """
        # Input size of the extractor (MobileNetV2's default resolution).
        self.img_height = 224
        self.img_width = 224
        # Make sure the database and model directories exist.
        self._ensure_parent_dir(database_path)
        self._ensure_parent_dir(model_path)
        self.database_path = database_path
        self.model_path = model_path
        self.database = {}
        self.stats_path = "db/recognition_stats.csv"
        self.stats = pd.DataFrame()
        # State of the most recent recognition, consumed by confirm_recognition().
        self.last_recognition_images = []
        self.last_recognition_index = None
        self.last_recognition_result = None
        self.enrolled_file_hashes = set()
        # Must exist before load_or_create_model(): save_model() checks it.
        self.feature_extractor = None
        self.load_or_create_model(model_path)
        self.load_database()
        self.load_stats()
        self.load_enrolled_hashes()
        # Smoke-test the extractor on a blank image.
        self.test_model_functionality()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory containing *path* if it is missing.

        This is a no-op for bare file names. Their dirname is "", and
        os.makedirs("") raises FileNotFoundError.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _json_path_for(model_path):
        """Return the path of the JSON architecture file paired with *model_path*.

        Only the extension is swapped. The old ``str.replace('.h5', '.json')``
        rewrote every ".h5" anywhere in the path. For paths without ".h5" it
        returned *model_path* itself, so the JSON then overwrote the saved model.
        """
        base, _ = os.path.splitext(model_path)
        json_path = base + ".json"
        if json_path == model_path:
            # model_path already ends in .json: keep the two files distinct.
            json_path = model_path + ".json"
        return json_path

    @staticmethod
    def _species_of(plant_id):
        """Return the species part of *plant_id*: the text before the first '#'."""
        return plant_id.split('#')[0]

    @staticmethod
    def _decode_feature(feature):
        """Convert one stored feature into a list of Python floats.

        Strings are base64-encoded float32 buffers, the format written by
        save_database(). Anything else is treated as a sequence of numbers.
        """
        if isinstance(feature, str):
            return np.frombuffer(base64.b64decode(feature), dtype=np.float32).tolist()
        return [float(x) for x in feature]

    def get_species_group_count(self, plant_id):
        """Count the groups (plant IDs) of the same species as *plant_id*.

        IDs have the form ``species#group``. An ID without '#' always yields 1.
        Otherwise the result is the number of database keys starting with
        ``species#``, which includes *plant_id* itself when it is enrolled.

        NOTE(review): IDs without '#' are not counted here, although
        list_plants() and export_database_info() group them with their
        ``species#...`` siblings. This is kept as-is because recorded
        statistics depend on it.
        """
        if '#' not in plant_id:
            return 1
        prefix = self._species_of(plant_id) + '#'
        return sum(1 for existing_id in self.database if existing_id.startswith(prefix))

    # ------------------------------------------------------------------
    # Model handling
    # ------------------------------------------------------------------

    def load_or_create_model(self, model_path):
        """Load the feature extractor from disk, or build and save a new one.

        Loading needs both *model_path* (weights) and its JSON architecture
        file. If either file is missing, or loading fails, a frozen ImageNet
        MobileNetV2 followed by global average pooling is built and saved. If
        building that fails too, create_fallback_model() is used.
        """
        json_path = self._json_path_for(model_path)
        if os.path.exists(model_path) and os.path.exists(json_path):
            print(f"加载现有模型: {model_path}")
            try:
                # Same seed as when the model was created, for consistency.
                set_seed(42)
                with open(json_path, 'r') as json_file:
                    model_json = json_file.read()
                model = model_from_json(model_json)
                model.load_weights(model_path)
                print("模型加载成功")
                self.feature_extractor = model
                return
            except Exception as e:
                print(f"加载模型失败: {e}")
                print("将创建新的特征提取模型")
        else:
            print("模型文件不存在,创建新的特征提取模型...")
        try:
            set_seed(42)
            base_model = MobileNetV2(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet',
            )
            base_model.trainable = False  # freeze the pretrained weights
            inputs = tf.keras.Input(shape=(self.img_height, self.img_width, 3))
            # training=False runs the frozen base model in inference mode.
            x = base_model(inputs, training=False)
            x = GlobalAveragePooling2D()(x)
            self.feature_extractor = Model(inputs, x)
            print("特征提取模型创建成功")
            # Persist it so later runs load the very same model.
            self.save_model()
        except Exception as e:
            print(f"创建特征提取模型失败: {e}")
            self.create_fallback_model()

    def create_fallback_model(self):
        """Build a small CNN extractor for when MobileNetV2 cannot be created.

        First tries a 3-layer CNN with global average pooling. If even that
        fails, uses a model that only global-average-pools the raw input.
        Either model is saved so later runs reload the same weights.
        """
        print("创建备用模型...")
        try:
            set_seed(42)
            model = tf.keras.Sequential([
                tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                                       input_shape=(self.img_height, self.img_width, 3),
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.MaxPooling2D((2, 2)),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu',
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.MaxPooling2D((2, 2)),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu',
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.GlobalAveragePooling2D(),
            ])
            self.feature_extractor = model
            print("备用模型创建成功")
            self.save_model()
        except Exception as e:
            print(f"创建备用模型失败: {e}")
            # Last resort: average the raw (preprocessed) pixels per channel.
            model = tf.keras.Sequential([
                tf.keras.layers.GlobalAveragePooling2D(input_shape=(self.img_height, self.img_width, 3)),
            ])
            self.feature_extractor = model
            self.save_model()

    def save_model(self, model_path=None):
        """Save the extractor with Model.save() plus its architecture as JSON.

        Args:
            model_path: destination; defaults to ``self.model_path``.

        Returns:
            True if the model file exists after saving, otherwise False. Errors
            are printed, not raised.
        """
        if model_path is None:
            model_path = self.model_path
        if self.feature_extractor is None:
            print("错误: feature_extractor 尚未初始化,无法保存模型")
            return False
        try:
            self._ensure_parent_dir(model_path)
            print(f"正在保存模型到: {model_path}")
            print(f"模型类型: {type(self.feature_extractor)}")
            print(f"模型层数: {len(self.feature_extractor.layers)}")
            try:
                self.feature_extractor.save(model_path)
                print(f"完整模型已保存到 {model_path}")
                # The architecture JSON is what load_or_create_model() rebuilds from.
                json_path = self._json_path_for(model_path)
                with open(json_path, 'w') as json_file:
                    json_file.write(self.feature_extractor.to_json())
                print(f"模型架构已保存到 {json_path}")
            except Exception as e:
                print(f"保存完整模型失败: {e}")
                return False
            if os.path.exists(model_path):
                file_size = os.path.getsize(model_path)
                print(f"模型文件已创建,大小: {file_size} 字节")
                return True
            print("错误: 模型文件未创建")
            return False
        except Exception as e:
            print(f"保存模型时出错: {e}")
            return False

    def check_model_files(self):
        """Print which model files exist on disk and the extractor's input/output shapes."""
        print("=== 模型文件检查 ===")
        if os.path.exists(self.model_path):
            size = os.path.getsize(self.model_path)
            print(f"主模型文件: {self.model_path} (大小: {size} 字节)")
        else:
            print(f"主模型文件: {self.model_path} (不存在)")
        base, _ = os.path.splitext(self.model_path)
        possible_files = [
            # The architecture file actually written by save_model().
            self._json_path_for(self.model_path),
            # Names used by other/older layouts.
            base + '_architecture.json',
            base + '_weights.h5',
        ]
        for file_path in possible_files:
            if os.path.exists(file_path):
                size = os.path.getsize(file_path)
                print(f"辅助文件: {file_path} (大小: {size} 字节)")
        if self.feature_extractor is None:
            print("错误: feature_extractor 尚未初始化")
            return
        print("=== 模型信息 ===")
        print(f"模型层数: {len(self.feature_extractor.layers)}")
        print(f"输入形状: {self.feature_extractor.input_shape}")
        print(f"输出形状: {self.feature_extractor.output_shape}")

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------

    def extract_features(self, image_paths, check_duplicates=True):
        """Extract one L2-normalised feature vector per usable image.

        Args:
            image_paths: image file paths.
            check_duplicates: skip images whose MD5 is already in
                ``enrolled_file_hashes`` (used during enrollment).

        Returns:
            ``(features, valid_paths)``: a 2-D array with one row per image
            that produced a feature, and the corresponding paths.

        Raises:
            ValueError: if no image produced a feature.
        """
        features = []
        valid_paths = []
        for img_path in image_paths:
            if not os.path.exists(img_path):
                print(f"警告: 图像不存在 {img_path}")
                continue
            if check_duplicates:
                file_hash = self.calculate_file_hash(img_path)
                if file_hash in self.enrolled_file_hashes:
                    print(f"跳过已录入图像: {os.path.basename(img_path)}")
                    continue
            try:
                img = image.load_img(img_path, target_size=(self.img_height, self.img_width))
                img_array = np.expand_dims(image.img_to_array(img), axis=0)  # batch of one
                # MobileNetV2's input preprocessing (also used for the fallback models).
                img_array = preprocess_input(img_array)
                feature = self.feature_extractor.predict(img_array, verbose=0).flatten()
                feature_norm = np.linalg.norm(feature)
                if feature_norm < 1e-10:
                    # A (near) zero vector cannot be normalised; skip this image.
                    print(f"警告: 图像 {os.path.basename(img_path)} 的特征范数接近零: {feature_norm}")
                    print("可能的原因: 模型未正确工作或图像预处理有问题")
                    continue
                features.append(feature / feature_norm)  # L2 normalisation
                valid_paths.append(img_path)
            except Exception as e:
                print(f"处理图像 {img_path} 时出错: {e}")
        if not features:
            raise ValueError("没有成功提取任何图像特征")
        return np.array(features), valid_paths

    def test_model_functionality(self):
        """Run the extractor on an all-zero image and report the resulting feature norm."""
        print("测试模型功能...")
        try:
            test_image = np.zeros((1, self.img_height, self.img_width, 3), dtype=np.float32)
            test_image = preprocess_input(test_image)
            test_features = self.feature_extractor.predict(test_image, verbose=0).flatten()
            if np.all(test_features == 0):
                print("警告: 模型提取的特征全为零!")
                print("可能的原因: 模型未正确初始化或权重未加载")
            else:
                print(f"模型测试通过,特征范数: {np.linalg.norm(test_features):.6f}")
                print(f"特征前5个值: {test_features[:5]}")
        except Exception as e:
            print(f"模型测试失败: {e}")

    # ------------------------------------------------------------------
    # Enrolled-file hashes
    # ------------------------------------------------------------------

    def load_enrolled_hashes(self):
        """Load the MD5 digests of already-enrolled images from HASHES_PATH."""
        hash_file = self.HASHES_PATH
        if os.path.exists(hash_file):
            try:
                with open(hash_file, 'r') as f:
                    # Ignore blank lines so "" never ends up in the set.
                    self.enrolled_file_hashes = {line.strip() for line in f if line.strip()}
                print(f"已加载 {len(self.enrolled_file_hashes)} 个已录入文件的哈希值")
            except Exception as e:
                print(f"加载哈希文件时出错: {e}")
                self.enrolled_file_hashes = set()

    def save_enrolled_hashes(self):
        """Write ``enrolled_file_hashes`` to HASHES_PATH, one digest per line."""
        hash_file = self.HASHES_PATH
        try:
            self._ensure_parent_dir(hash_file)
            with open(hash_file, 'w') as f:
                # Sorted so the file content is stable between runs.
                f.writelines(file_hash + '\n' for file_hash in sorted(self.enrolled_file_hashes))
        except Exception as e:
            print(f"保存哈希文件时出错: {e}")

    def calculate_file_hash(self, file_path):
        """Return the hex MD5 of *file_path*, or None if it cannot be read.

        MD5 only identifies duplicate files here; it is not used for security.
        """
        hash_md5 = hashlib.md5()
        try:
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    hash_md5.update(chunk)
            return hash_md5.hexdigest()
        except Exception as e:
            print(f"计算文件哈希时出错 {file_path}: {e}")
            return None

    # ------------------------------------------------------------------
    # Database persistence
    # ------------------------------------------------------------------

    def load_database(self):
        """Load the plant database from ``database_path`` into ``self.database``.

        Unparsable files or files with a wrong structure go through
        try_repair_database(). On success the features are decoded to float
        lists. The database is then saved back if it is not empty.
        """
        if os.path.exists(self.database_path):
            try:
                with open(self.database_path, 'r') as f:
                    content = f.read().strip()
                if not content:
                    print("数据库文件为空,创建新的空数据库")
                    self.database = {}
                    return
                self.database = json.loads(content)
                if not self.validate_database_structure():
                    print("数据库结构验证失败,将尝试修复")
                    self.try_repair_database()
                    return
                self._normalize_database()
                print(f"已加载数据库,包含 {len(self.database)} 个植物")
            except Exception as e:
                print(f"加载数据库时出错: {e}")
                self.try_repair_database()
        else:
            self.database = {}
            print("创建新的空数据库")
        if self.database:
            self.save_database()

    def _normalize_database(self):
        """Convert freshly parsed JSON entries into the in-memory representation.

        - Entries with an empty ID, or that cannot be converted, are dropped.
        - ``image_count`` is coerced to int.
        - Features are decoded per vector. An empty feature list is kept; it is
          no longer an IndexError that deleted the plant.
        - Missing bookkeeping fields get defaults, so that a repaired database
          stays usable.
        """
        if not isinstance(self.database, dict):
            print("数据库根节点不是对象,创建新的空数据库")
            self.database = {}
            return
        for plant_id in list(self.database.keys()):
            if not plant_id or plant_id.strip() == "":
                print(f"警告: 发现空的植物ID,已删除")
                del self.database[plant_id]
                continue
            plant_data = self.database[plant_id]
            try:
                plant_data["image_count"] = int(plant_data.get("image_count", 0))
                plant_data["features"] = [
                    self._decode_feature(feature) for feature in (plant_data.get("features") or [])
                ]
                plant_data.setdefault("enrollment_date", datetime.now().isoformat())
                plant_data.setdefault("update_history", [])
            except Exception as e:
                print(f"处理植物 {plant_id} 的数据时出错: {e}")
                del self.database[plant_id]

    def validate_database_structure(self):
        """Return True if the database is a dict of dicts that all have REQUIRED_KEYS."""
        if not isinstance(self.database, dict):
            return False
        return all(
            isinstance(plant_data, dict) and all(key in plant_data for key in self.REQUIRED_KEYS)
            for plant_data in self.database.values()
        )

    def try_repair_database(self):
        """Back up the database file, strip trailing commas, reparse, normalise and save.

        On failure the in-memory database is reset to empty. The file is not
        saved in that case, so the original stays on disk next to its
        ``.backup`` copy.
        """
        import re

        try:
            if os.path.exists(self.database_path):
                backup_path = self.database_path + ".backup"
                shutil.copy2(self.database_path, backup_path)
                print(f"已创建数据库备份: {backup_path}")
            with open(self.database_path, 'r') as f:
                content = f.read()
            # Remove trailing commas before '}' or ']' whatever whitespace or
            # indentation separates them. The old replace() only matched
            # commas followed by a newline and an unindented bracket.
            # NOTE: this would also touch such a sequence inside a string
            # value. Base64 features and ISO dates never contain one.
            content = re.sub(r',\s*([}\]])', r'\1', content)
            self.database = json.loads(content)
            # Decode features too, otherwise they stay base64 strings in memory.
            self._normalize_database()
            print("数据库修复成功")
            self.save_database()
        except Exception as e:
            print(f"数据库修复失败: {e}")
            print("将创建新的空数据库")
            self.database = {}
            # Deliberately not saved until valid data exists.

    def save_database(self):
        """Write the database to ``database_path`` with features as base64 float32.

        The file is written to a temporary sibling first and then swapped in
        with os.replace(), so a crash mid-write cannot corrupt the existing
        database. Errors are printed, not raised.
        """
        try:
            serializable_database = {}
            for plant_id, plant_data in self.database.items():
                # Base64 of the raw float32 bytes round-trips losslessly.
                encoded_features = [
                    base64.b64encode(np.array(feature, dtype=np.float32).tobytes()).decode('utf-8')
                    for feature in plant_data["features"]
                ]
                serializable_database[plant_id] = {
                    "features": encoded_features,
                    "enrollment_date": plant_data["enrollment_date"],
                    "update_history": plant_data["update_history"],
                    "image_count": int(plant_data["image_count"]),
                }
            self._ensure_parent_dir(self.database_path)
            tmp_path = self.database_path + ".tmp"
            with open(tmp_path, 'w') as f:
                json.dump(serializable_database, f, indent=4)
            os.replace(tmp_path, self.database_path)
            print(f"数据库已保存到 {self.database_path}")
        except Exception as e:
            print(f"保存数据库时出错: {e}")

    def delete_plant(self, plant_id, ask_confirmation=True):
        """Remove *plant_id* from the database and save it.

        Args:
            plant_id: ID to delete.
            ask_confirmation: prompt on stdin first (the default, as before).
                Pass False for non-interactive callers such as batch tests.

        Returns:
            ``(success, message)``.

        NOTE: hashes of the deleted plant's images remain in
        ``enrolled_file_hashes`` because they are not tracked per plant. The
        same image files are therefore still skipped by a later enroll_plant().
        """
        if plant_id not in self.database:
            return False, f"植物 {plant_id} 不存在"
        if ask_confirmation:
            confirm = input(f"确定要删除植物 {plant_id} 吗?此操作不可撤销!(y/n): ").strip().lower()
            if confirm != 'y':
                return False, "取消删除"
        del self.database[plant_id]
        self.save_database()
        return True, f"植物 {plant_id} 已删除"

    def export_database_info(self):
        """Write a human-readable summary of the database to db/database_info.txt.

        Returns:
            True on success, False (after printing the error) otherwise.
        """
        try:
            export_path = "db/database_info.txt"
            self._ensure_parent_dir(export_path)
            species_groups = {}
            for plant_id in self.database:
                species = self._species_of(plant_id)
                species_groups[species] = species_groups.get(species, 0) + 1
            # Explicit UTF-8: the report is Chinese text, which the platform
            # default encoding (e.g. cp1252 on Windows) may not encode.
            with open(export_path, 'w', encoding='utf-8') as f:
                f.write("=== 植物刷脸机数据库信息 ===\n\n")
                f.write(f"导出时间: {datetime.now().isoformat()}\n")
                f.write(f"植物数量: {len(self.database)}\n\n")
                f.write("=== 同种植物组数统计 ===\n")
                for species, count in species_groups.items():
                    f.write(f"{species}: {count} 组\n")
                f.write("\n")
                for plant_id, data in self.database.items():
                    f.write(f"植物编号: {plant_id}\n")
                    f.write(f" 录入日期: {data['enrollment_date']}\n")
                    f.write(f" 图像数量: {data['image_count']}\n")
                    f.write(f" 特征向量数量: {len(data['features'])}\n")
                    f.write(f" 更新历史: {len(data['update_history'])} 次\n\n")
            print(f"数据库信息已导出到 {export_path}")
            return True
        except Exception as e:
            print(f"导出数据库信息时出错: {e}")
            return False

    # ------------------------------------------------------------------
    # Recognition statistics
    # ------------------------------------------------------------------

    def load_stats(self):
        """Load the statistics CSV, or start an empty table with STATS_COLUMNS.

        Columns missing from an older CSV are added and filled with None.
        """
        if os.path.exists(self.stats_path):
            try:
                self.stats = pd.read_csv(self.stats_path, index_col=0)
                for col in self.STATS_COLUMNS:
                    if col not in self.stats.columns:
                        self.stats[col] = None
                print(f"已加载统计信息,包含 {len(self.stats)} 条记录")
            except Exception as e:
                print(f"加载统计信息时出错: {e}")
                self.stats = pd.DataFrame(columns=list(self.STATS_COLUMNS))
        else:
            self.stats = pd.DataFrame(columns=list(self.STATS_COLUMNS))
            print("创建新的统计表格")

    def save_stats(self):
        """Convert NumPy scalars to Python floats and write the table to ``stats_path``."""
        def to_native(value):
            return float(value) if isinstance(value, (np.integer, np.floating)) else value

        try:
            # DataFrame.applymap was renamed DataFrame.map in pandas 2.1 and is
            # gone in pandas 3, so use whichever this pandas provides.
            if hasattr(pd.DataFrame, "map"):
                self.stats = self.stats.map(to_native)
            else:
                self.stats = self.stats.applymap(to_native)
            self._ensure_parent_dir(self.stats_path)
            self.stats.to_csv(self.stats_path)
            print(f"统计信息已保存到 {self.stats_path}")
        except Exception as e:
            print(f"保存统计信息时出错: {e}")

    # ------------------------------------------------------------------
    # Enrollment / recognition / feedback
    # ------------------------------------------------------------------

    def enroll_plant(self, plant_id, image_paths):
        """Add one averaged feature vector for *plant_id* from *image_paths*.

        Images already enrolled (same MD5) are skipped. The plant entry is
        created on first enrollment. Later enrollments append another vector.

        Returns:
            ``(success, message)``.
        """
        if len(image_paths) < 1:
            return False, "需要至少1张图像进行录入"
        # An empty ID would be stored now and silently dropped on the next load.
        if plant_id is None or not str(plant_id).strip():
            return False, "植物编号不能为空"
        try:
            features, valid_paths = self.extract_features(image_paths, check_duplicates=True)
            if len(valid_paths) < 1:
                return False, "没有有效的图像可用于录入"
            avg_feature = np.mean(features, axis=0)
            if plant_id not in self.database:
                self.database[plant_id] = {
                    "features": [],
                    "enrollment_date": datetime.now().isoformat(),
                    "update_history": [],
                    "image_count": 0,
                }
            entry = self.database[plant_id]
            # Append (never overwrite) so earlier enrollments keep contributing.
            entry["features"].append(avg_feature.tolist())
            entry["image_count"] += len(valid_paths)
            update_type = "enrollment" if len(entry["features"]) == 1 else "update"
            entry["update_history"].append({
                "date": datetime.now().isoformat(),
                "num_images": int(len(valid_paths)),
                "type": update_type,
            })
            for img_path in valid_paths:
                file_hash = self.calculate_file_hash(img_path)
                if file_hash:
                    self.enrolled_file_hashes.add(file_hash)
            self.save_enrolled_hashes()
            self.save_database()
            return True, f"植物 {plant_id} 录入成功,使用了 {len(valid_paths)} 张图像"
        except Exception as e:
            return False, f"录入过程中出错: {e}"

    def recognize_plant(self, image_paths, threshold=0.7):
        """Identify the plant shown in *image_paths*.

        The per-image features are averaged. Each plant's score is the maximum
        cosine similarity over its stored vectors. A match requires
        ``score >= threshold``. The attempt is logged to the statistics table
        and remembered for confirm_recognition().

        Returns:
            ``(plant_id or None, message, best_similarity, scores)``, where
            *scores* maps every enrolled plant ID to its score.
        """
        if len(image_paths) < 1:
            return None, "需要至少1张图像进行识别", 0.0, {}
        try:
            print(f"当前数据库中有 {len(self.database)} 个植物:")
            for plant_id, data in self.database.items():
                print(f" {plant_id}: {len(data['features'])} 个特征向量, {data['image_count']} 张图片")
            # No duplicate check: recognising with already-enrolled images is allowed.
            features, valid_paths = self.extract_features(image_paths, check_duplicates=False)
            if len(valid_paths) < 1:
                return None, "没有有效的图像可用于识别", 0.0, {}
            self.last_recognition_images = valid_paths.copy()
            avg_feature = np.mean(features, axis=0)
            # The query norm is the same for every comparison.
            norm_a = np.linalg.norm(avg_feature)
            best_match = None
            best_similarity = 0
            similarity_distribution = {}
            for plant_id, plant_data in self.database.items():
                similarities = []
                for stored_feature in plant_data["features"]:
                    stored = np.array(stored_feature, dtype=np.float32)
                    dot_product = np.dot(avg_feature, stored)
                    norm_b = np.linalg.norm(stored)
                    if norm_a < 1e-10 or norm_b < 1e-10:
                        similarity = 0.0  # avoid dividing by zero
                    else:
                        similarity = dot_product / (norm_a * norm_b)
                    similarities.append(similarity)
                plant_similarity = max(similarities) if similarities else 0
                similarity_distribution[plant_id] = plant_similarity
                if plant_similarity > best_similarity:
                    best_similarity = plant_similarity
                    best_match = plant_id
            print("相似度分数:")
            for plant_id, score in similarity_distribution.items():
                print(f" {plant_id}: {score:.4f}")
            species_group_count = 0
            if best_match and best_match in self.database:
                species_group_count = self.get_species_group_count(best_match)
            matched = best_similarity >= threshold
            recognition_record = {
                'plant_id': 'unknown',
                'recognition_time': datetime.now().isoformat(),
                'num_images': int(len(valid_paths)),
                'recognized_id': best_match if matched else 'unknown',
                'confidence': float(best_similarity),
                'admin_feedback': 'pending',
                'true_id': None,
                'correct': None,
                'database_size': sum(plant_data["image_count"] for plant_data in self.database.values()),
                'plant_image_count': self.database[best_match]["image_count"] if best_match in self.database else 0,
                'species_group_count': int(species_group_count),
            }
            self.stats = pd.concat([self.stats, pd.DataFrame([recognition_record])], ignore_index=True)
            # ignore_index=True makes the row label equal its position.
            self.last_recognition_index = len(self.stats) - 1
            self.last_recognition_result = best_match if matched else None
            self.save_stats()
            if matched:
                return (best_match,
                        f"识别为植物 {best_match} (置信度: {best_similarity:.3f}, 同种组数: {species_group_count})",
                        best_similarity, similarity_distribution)
            return None, f"未找到匹配的植物 (最高相似度: {best_similarity:.3f})", best_similarity, similarity_distribution
        except Exception as e:
            return None, f"识别过程中出错: {e}", 0.0, {}

    def confirm_recognition(self, plant_id=None, is_correct=True, enroll_images=False):
        """Record administrator feedback for the most recent recognition.

        Args:
            plant_id: the true plant ID. When the result is confirmed correct
                and this is None, the recognised ID is used.
            is_correct: whether the recognition was right.
            enroll_images: also enroll the recognition images under *plant_id*.

        The statistics row is filled using the database state before any
        enrollment this call performs.

        Returns:
            ``(success, message)``.
        """
        if self.last_recognition_index is None:
            return False, "没有最近的识别记录"
        try:
            row = self.last_recognition_index
            record = self.stats.iloc[row]
            if is_correct and plant_id is None and self.last_recognition_result:
                plant_id = self.last_recognition_result
            known = bool(plant_id) and plant_id in self.database
            current_db_size = sum(plant_data["image_count"] for plant_data in self.database.values())
            current_plant_count = self.database[plant_id]["image_count"] if known else 0
            species_group_count = self.get_species_group_count(plant_id) if known else 0
            if is_correct:
                feedback, correct = f"correct_{plant_id}", True
            else:
                feedback, correct = "incorrect", False
            updates = {
                'admin_feedback': feedback,
                'true_id': plant_id,
                'correct': correct,
                'database_size': int(current_db_size),
                'plant_image_count': int(current_plant_count),
                'species_group_count': int(species_group_count),
            }
            for column, value in updates.items():
                self.stats.at[row, column] = value
            self.save_stats()
            if enroll_images and plant_id is not None and self.last_recognition_images:
                success, message = self.enroll_plant(plant_id, self.last_recognition_images)
                if success:
                    print(f"已将图像加入到植物 {plant_id} 的特征库中: {message}")
                else:
                    print(f"添加图像到特征库失败: {message}")
            # Append to the plant's history without re-initialising its data.
            if is_correct and plant_id in self.database:
                self.database[plant_id]["update_history"].append({
                    "date": datetime.now().isoformat(),
                    "num_images": int(record['num_images']),
                    "type": "confirmation_update",
                })
                self.save_database()
            return True, "反馈已记录"
        except Exception as e:
            return False, f"确认过程中出错: {e}"

    # ------------------------------------------------------------------
    # Reporting and maintenance
    # ------------------------------------------------------------------

    def list_plants(self):
        """Print every enrolled plant, then the number of groups per species."""
        if not self.database:
            print("数据库中暂无植物")
            return
        print("\n=== 已录入的植物 ===")
        species_stats = {}
        for plant_id, data in self.database.items():
            species = self._species_of(plant_id)
            species_stats[species] = species_stats.get(species, 0) + 1
            print(f"编号: {plant_id}, 录入日期: {data['enrollment_date']}, 图像数量: {data['image_count']}")
        print("\n=== 同种植物组数统计 ===")
        for species, count in species_stats.items():
            print(f"{species}: {count} 组")

    @staticmethod
    def _print_group_accuracy(confirmed_stats, column, title, label):
        """Print accuracy of *confirmed_stats* grouped by the values of *column*.

        Missing values are dropped first. Rows logged before a column existed
        hold None/NaN, which cannot be sorted together with numbers (TypeError)
        and never compare equal to themselves.
        """
        if column not in confirmed_stats.columns:
            return
        print(title)
        for value in sorted(confirmed_stats[column].dropna().unique()):
            group = confirmed_stats[confirmed_stats[column] == value]
            group_total = len(group)
            group_correct = len(group[group['correct'] == True])  # noqa: E712 (object column)
            group_accuracy = group_correct / group_total * 100 if group_total > 0 else 0
            print(f" {label.format(value)}: {group_accuracy:.2f}% ({group_correct}/{group_total})")

    def show_accuracy_stats(self):
        """Print overall accuracy of confirmed recognitions and per-factor breakdowns."""
        if len(self.stats) == 0:
            print("暂无统计信息")
            return
        confirmed_stats = self.stats[self.stats['correct'].notna()]
        if len(confirmed_stats) == 0:
            print("暂无已确认的识别记录")
            return
        total = len(confirmed_stats)
        correct = len(confirmed_stats[confirmed_stats['correct'] == True])  # noqa: E712
        accuracy = correct / total * 100 if total > 0 else 0
        print("\n=== 识别准确率统计 ===")
        print(f"总识别次数: {total}")
        print(f"正确识别次数: {correct}")
        print(f"总体准确率: {accuracy:.2f}%")
        breakdowns = (
            ('num_images', "\n按图片数量分组的准确率:", "{}张图片"),
            ('database_size', "\n按数据库大小分组的准确率:", "数据库大小{}"),
            ('plant_image_count', "\n按植物图片数量分组的准确率:", "植物图片数{}"),
            ('species_group_count', "\n按同种植物组数分组的准确率:", "同种组数{}"),
        )
        for column, title, label in breakdowns:
            self._print_group_accuracy(confirmed_stats, column, title, label)

    def clear_temp_folder(self, temp_dir=None):
        """Delete every regular file directly inside *temp_dir* (default TEMP_DIR).

        Returns:
            True if the folder is now empty of files, False if it does not
            exist or a deletion failed.
        """
        if temp_dir is None:
            temp_dir = self.TEMP_DIR
        if not os.path.exists(temp_dir):
            print(f"临时文件夹 {temp_dir} 不存在")
            return False
        try:
            files = [f for f in os.listdir(temp_dir) if os.path.isfile(os.path.join(temp_dir, f))]
            if not files:
                print("临时文件夹为空")
                return True
            for file in files:
                os.remove(os.path.join(temp_dir, file))
            print(f"已清空临时文件夹,删除了 {len(files)} 个文件")
            return True
        except Exception as e:
            print(f"清空临时文件夹时出错: {e}")
            return False

    def validate_database(self):
        """Print per-plant feature counts and warn about inconsistent vector lengths."""
        print("验证数据库完整性...")
        for plant_id, data in self.database.items():
            feature_count = len(data["features"])
            print(f"植物 {plant_id}: {feature_count} 个特征向量, {data['image_count']} 张图片")
            if feature_count > 0:
                first_dim = len(data["features"][0])
                for i, feature in enumerate(data["features"]):
                    if len(feature) != first_dim:
                        print(f" 警告: 特征向量 {i} 的维度不一致")
        print("\n检查特征向量格式:")
        for plant_id, data in self.database.items():
            if data["features"]:
                sample_feature = np.array(data["features"][0])
                print(f"植物 {plant_id}: 特征向量形状 {sample_feature.shape}, 数据类型 {sample_feature.dtype}")

    def reset_database(self, ask_confirmation=True):
        """Erase all plants and enrolled-file hashes (use with care).

        Args:
            ask_confirmation: prompt on stdin first (default). Pass False for
                non-interactive callers.
        """
        if ask_confirmation:
            confirm = input("确定要重置数据库吗?所有数据将丢失!(y/n): ")
            if confirm.strip().lower() != 'y':
                return
        self.database = {}
        self.save_database()
        self.enrolled_file_hashes = set()
        self.save_enrolled_hashes()
        print("数据库已重置")

    def test_feature_consistency(self, image_path):
        """Extract features from *image_path* three times and report their mean cosine similarity.

        Returns:
            The mean pairwise similarity, or None if the image is missing or
            fewer than two extractions succeeded.
        """
        if not os.path.exists(image_path):
            print(f"图像不存在: {image_path}")
            return
        features = []
        for attempt in range(1, 4):  # three runs keep the check quick
            try:
                feature, _ = self.extract_features([image_path], check_duplicates=False)
                if len(feature) > 0:
                    features.append(feature[0])
                else:
                    print(f"第 {attempt} 次特征提取失败: 没有提取到特征")
            except Exception as e:
                print(f"第 {attempt} 次特征提取失败: {e}")
        if len(features) < 2:
            print("特征提取失败,无法进行一致性测试")
            return
        similarities = []
        for i in range(len(features)):
            for j in range(i + 1, len(features)):
                norm_a = np.linalg.norm(features[i])
                norm_b = np.linalg.norm(features[j])
                if norm_a < 1e-10 or norm_b < 1e-10:
                    similarity = 0.0
                else:
                    similarity = np.dot(features[i], features[j]) / (norm_a * norm_b)
                similarities.append(similarity)
        avg_similarity = np.mean(similarities)
        print(f"特征提取一致性测试: 平均相似度 = {avg_similarity:.6f}")
        print("特征向量前5个值:")
        for i, feature in enumerate(features):
            print(f" 第 {i + 1} 次: {feature[:5]}")
        return avg_similarity
# Interactive command-line interface.

# Lower-case image file extensions picked up from the temp directory.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
# Directory where the user drops images to enroll or recognise.
TEMP_IMAGE_DIR = "imgTemp"
# Command list shown at start-up and after an unknown command.
COMMAND_HELP = ("enroll, recognize, confirm, stats, list, accuracy, clear, validate, "
                "reset, test, delete, export, check_model, exit")


def list_image_files(image_dir=TEMP_IMAGE_DIR):
    """Return the paths of the image files directly inside *image_dir*.

    Extensions are matched case-insensitively against IMAGE_EXTENSIONS.
    Sub-directories are skipped even if their names look like images. The
    order follows os.listdir().
    """
    return [
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(image_dir, name))
    ]


def _collect_temp_images():
    """Return the images in TEMP_IMAGE_DIR, or None after printing why there are none."""
    if not os.path.exists(TEMP_IMAGE_DIR):
        print("错误: 目录不存在")
        return None
    image_paths = list_image_files(TEMP_IMAGE_DIR)
    if not image_paths:
        print("错误: 目录中没有找到图像文件")
        return None
    return image_paths


def _enroll_from_temp(system):
    """Ask for a plant ID, enroll the temp-directory images, optionally delete them."""
    plant_id = input("请输入植物编号: ").strip()
    if not plant_id:
        print("错误: 植物编号不能为空")
        return
    image_paths = _collect_temp_images()
    if image_paths is None:
        return
    success, message = system.enroll_plant(plant_id, image_paths)
    print(message)
    if success and input("是否删除已处理的图像? (y/n): ").strip().lower() == 'y':
        for img_path in image_paths:
            try:
                os.remove(img_path)
            except OSError as e:
                # Report it instead of silently swallowing every exception.
                print(f"删除图像 {img_path} 失败: {e}")
        print("已删除图像文件")


def _recognize_from_temp(system):
    """Recognise the plant shown in the temp-directory images and print the result."""
    image_paths = _collect_temp_images()
    if image_paths is None:
        return
    # The per-plant similarity scores are returned too but not printed in this version.
    _, message, _, _ = system.recognize_plant(image_paths)
    print(message)


def _confirm_last_recognition(system):
    """Show the latest recognition and record the administrator's verdict."""
    if system.last_recognition_index is None:
        print("错误: 没有最近的识别记录,请先进行识别")
        return
    last_record = system.stats.iloc[system.last_recognition_index]
    print(f"最近识别记录 (索引 {system.last_recognition_index}):")
    print(f" 识别结果: {last_record['recognized_id']}")
    print(f" 置信度: {last_record['confidence']:.3f}")
    print(f" 图片数量: {last_record['num_images']}")
    print(f" 同种植物组数: {last_record['species_group_count']}")
    feedback = input("识别是否正确? (y/n): ").strip().lower()
    if feedback == 'y':
        enroll_images = input("是否将这些图像加入到数据库中? (y/n): ").strip().lower() == 'y'
        plant_id = system.last_recognition_result
        print(f"使用识别结果作为植物编号: {plant_id}")
        _, message = system.confirm_recognition(plant_id, True, enroll_images)
    else:
        # strip() so a stray space does not create a second, distinct plant ID.
        plant_id = input("请输入正确的植物编号: ").strip()
        enroll_images = input("是否将这些图像加入到数据库中? (y/n): ").strip().lower() == 'y'
        _, message = system.confirm_recognition(plant_id, False, enroll_images)
    print(message)


def _test_feature_consistency_on_temp(system):
    """Run the feature-consistency check on the first image in TEMP_IMAGE_DIR."""
    if not os.path.exists(TEMP_IMAGE_DIR):
        print("imgTemp目录不存在")
        return
    image_paths = list_image_files(TEMP_IMAGE_DIR)
    if not image_paths:
        print("imgTemp目录中没有图像文件")
        return
    image_path = image_paths[0]
    print(f"使用图像 {os.path.basename(image_path)} 进行测试")
    system.test_feature_consistency(image_path)


def main():
    """Run the interactive plant-recognition shell until the user enters 'exit'."""
    os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)
    os.makedirs("db", exist_ok=True)
    model_path = input("请输入模型文件路径(直接回车使用默认路径): ").strip()
    if model_path == "":
        model_path = "db/plant_model.h5"
    system = PlantFaceRecognitionSystem(model_path=model_path)
    print("\n=== 植物刷脸机系统 ===")
    print(f"命令: {COMMAND_HELP}")
    print("提示: 请将图像放在 imgTemp 目录中")
    while True:
        command = input("\n请输入命令: ").strip().lower()
        if command == "exit":
            print("保存数据并退出系统...")
            system.save_database()
            system.save_stats()
            system.save_enrolled_hashes()
            break
        elif command in ("enroll", "en"):
            _enroll_from_temp(system)
        elif command in ("recognize", "rec"):
            _recognize_from_temp(system)
        elif command in ("confirm", "con"):
            _confirm_last_recognition(system)
        elif command == "stats":
            print("\n=== 识别统计 ===")
            print(system.stats)
        elif command == "check_model":
            system.check_model_files()
        elif command in ("list", "ls"):
            system.list_plants()
        elif command == "accuracy":
            system.show_accuracy_stats()
        elif command in ("clear", "cl"):
            if system.clear_temp_folder():
                print("临时文件夹已清空")
        elif command == "validate":
            system.validate_database()
        elif command == "reset":
            system.reset_database()
        elif command == "test":
            _test_feature_consistency_on_temp(system)
        elif command == "delete":
            plant_id = input("请输入要删除的植物编号: ").strip()
            _, message = system.delete_plant(plant_id)
            print(message)
        elif command == "export":
            if system.export_database_info():
                print("数据库信息导出成功")
        else:
            print(f"未知命令,请使用: {COMMAND_HELP}")


if __name__ == "__main__":
    main()
# End of gist.