# This is a plant face recognition machine.
# The class PlantFaceRecognitionSystem can be imported by other programs.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TensorFlow log output

import sys

# Try different import styles
try:
    import tensorflow as tf
    from tensorflow.keras.models import load_model, Model, model_from_json
    from tensorflow.keras.applications import MobileNetV2
    from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
    TENSORFLOW_AVAILABLE = True
except ImportError as e:
    print(f"TensorFlow import error: {e}")
    # Try fallback imports
    try:
        import tensorflow as tf
        from tensorflow import keras
        from keras.models import load_model, Model, model_from_json
        from keras.applications import MobileNetV2
        from keras.layers import GlobalAveragePooling2D, Dense
        from keras.preprocessing import image
        from keras.applications.mobilenet_v2 import preprocess_input
        TENSORFLOW_AVAILABLE = True
    except ImportError:
        print("All TensorFlow imports failed")
        sys.exit(1)

import json
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.metrics.pairwise import cosine_similarity
import warnings
import shutil
import hashlib
import base64
import random

warnings.filterwarnings('ignore')


# Set random seeds to ensure reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


set_seed(42)  # Set the random seed at program start
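# Note (assumption, not in the original gist): seeding alone does not
# guarantee bit-identical results on every backend. On TF >= 2.9 one can
# additionally request deterministic kernels, e.g.:
#
#   tf.config.experimental.enable_op_determinism()
#
# This may slow inference down and is left disabled here.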
# Set the working directory (".../working/dir" is a placeholder path)
os.chdir(".../working/dir")


# This class can be imported by other code, especially for batch tests.
class PlantFaceRecognitionSystem:
    def __init__(self, database_path="db/plant_faces.json", model_path="db/plant_model.h5"):
        # Image dimensions
        self.img_height = 224
        self.img_width = 224
        # Make sure the database directories exist (skip paths with no directory part)
        for path in (database_path, model_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        # Initialize the database and statistics
        self.database_path = database_path
        self.model_path = model_path
        self.database = {}
        self.stats_path = "db/recognition_stats.csv"
        self.stats = pd.DataFrame()
        self.last_recognition_images = []
        self.last_recognition_index = None
        self.last_recognition_result = None
        self.enrolled_file_hashes = set()
        # Initialize feature_extractor to None first
        self.feature_extractor = None
        # Load or create the feature extraction model
        self.load_or_create_model(model_path)
        # Load the database and other data
        self.load_database()
        self.load_stats()
        self.load_enrolled_hashes()
        # Check that the model works
        self.test_model_functionality()

    def get_species_group_count(self, plant_id):
        """Count the number of groups of the same species (based on the plant name)."""
        if '#' not in plant_id:
            # No '#' means there is only a single plant of this species
            return 1
        species_name = plant_id.split('#')[0]
        group_count = 0
        for existing_id in self.database.keys():
            if existing_id.startswith(species_name + '#'):
                group_count += 1
        return group_count
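    # Example (illustrative, not part of the original gist): with the IDs
    # {"rose#1", "rose#2", "tulip#1"} enrolled, get_species_group_count("rose#1")
    # returns 2, while an ID without '#' such as "tulip" returns 1.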
    def load_or_create_model(self, model_path):
        """Load an existing model or create a new feature extraction model."""
        # Check whether the model files exist
        json_path = model_path.replace('.h5', '.json')
        if os.path.exists(model_path) and os.path.exists(json_path):
            print(f"Loading existing model: {model_path}")
            try:
                # Use the same seed for consistency
                set_seed(42)
                # Load the model architecture
                with open(json_path, 'r') as json_file:
                    model_json = json_file.read()
                model = model_from_json(model_json)
                # Load the model weights
                model.load_weights(model_path)
                print("Model loaded successfully")
                self.feature_extractor = model
                return
            except Exception as e:
                print(f"Failed to load model: {e}")
                print("A new feature extraction model will be created")
        else:
            print("Model files not found, creating a new feature extraction model...")
        # Use a pretrained MobileNetV2 as the backbone
        try:
            # Use the same seed for consistency
            set_seed(42)
            base_model = MobileNetV2(
                input_shape=(self.img_height, self.img_width, 3),
                include_top=False,
                weights='imagenet'
            )
            base_model.trainable = False  # Freeze the backbone weights
            # Add a global average pooling head
            inputs = tf.keras.Input(shape=(self.img_height, self.img_width, 3))
            x = base_model(inputs, training=False)  # Important: training=False
            x = GlobalAveragePooling2D()(x)
            model = Model(inputs, x)
            self.feature_extractor = model
            print("Feature extraction model created successfully")
            # Automatically save the newly created model
            self.save_model()
        except Exception as e:
            print(f"Failed to create feature extraction model: {e}")
            # Fall back to a simple model
            self.create_fallback_model()
    def create_fallback_model(self):
        """Create a fallback model (used when the main model fails)."""
        print("Creating fallback model...")
        try:
            # Use the same seed for consistency
            set_seed(42)
            # A simple CNN as the fallback
            model = tf.keras.Sequential([
                tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                                       input_shape=(self.img_height, self.img_width, 3),
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.MaxPooling2D((2, 2)),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu',
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.MaxPooling2D((2, 2)),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu',
                                       kernel_initializer='glorot_uniform'),
                tf.keras.layers.GlobalAveragePooling2D(),
            ])
            self.feature_extractor = model
            print("Fallback model created successfully")
            # Automatically save the fallback model
            self.save_model()
        except Exception as e:
            print(f"Failed to create fallback model: {e}")
            # Fall back to the simplest possible model
            model = tf.keras.Sequential([
                tf.keras.layers.GlobalAveragePooling2D(input_shape=(self.img_height, self.img_width, 3)),
            ])
            self.feature_extractor = model
            # Automatically save the simplest model
            self.save_model()
    def save_model(self, model_path=None):
        """Save the feature extraction model."""
        if model_path is None:
            model_path = self.model_path
        # Check that feature_extractor exists
        if self.feature_extractor is None:
            print("Error: feature_extractor is not initialized yet, cannot save the model")
            return False
        try:
            # Make sure the directory exists
            if os.path.dirname(model_path):
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
            print(f"Saving model to: {model_path}")
            print(f"Model type: {type(self.feature_extractor)}")
            print(f"Number of layers: {len(self.feature_extractor.layers)}")
            # Try to save the full model
            try:
                self.feature_extractor.save(model_path)
                print(f"Full model saved to {model_path}")
                # Also save the model architecture as JSON
                json_path = model_path.replace('.h5', '.json')
                with open(json_path, 'w') as json_file:
                    json_file.write(self.feature_extractor.to_json())
                print(f"Model architecture saved to {json_path}")
            except Exception as e:
                print(f"Failed to save full model: {e}")
                return False
            # Verify that the file was actually created
            if os.path.exists(model_path):
                file_size = os.path.getsize(model_path)
                print(f"Model file created, size: {file_size} bytes")
                return True
            else:
                print("Error: model file was not created")
                return False
        except Exception as e:
            print(f"Error while saving model: {e}")
            return False
    def extract_features(self, image_paths, check_duplicates=True):
        """Extract features from a set of images."""
        features = []
        valid_paths = []
        for img_path in image_paths:
            if not os.path.exists(img_path):
                print(f"Warning: image does not exist {img_path}")
                continue
            # Check whether this file has already been enrolled (only when requested)
            if check_duplicates:
                file_hash = self.calculate_file_hash(img_path)
                if file_hash in self.enrolled_file_hashes:
                    print(f"Skipping already enrolled image: {os.path.basename(img_path)}")
                    continue
            try:
                # Load and preprocess the image
                img = image.load_img(img_path, target_size=(self.img_height, self.img_width))
                img_array = image.img_to_array(img)
                img_array = np.expand_dims(img_array, axis=0)
                # Use MobileNetV2's preprocessing function
                img_array = preprocess_input(img_array)
                # Extract features -- a deterministic forward pass
                feature = self.feature_extractor.predict(img_array, verbose=0)
                # Normalize the feature vector to improve separability
                feature = feature.flatten()
                feature_norm = np.linalg.norm(feature)
                # Check whether the feature is all zeros
                if feature_norm < 1e-10:  # effectively zero
                    print(f"Warning: feature norm of image {os.path.basename(img_path)} is close to zero: {feature_norm}")
                    print("Possible cause: the model is not working correctly or preprocessing failed")
                    # Skip this feature but keep processing other images
                    continue
                feature = feature / feature_norm  # L2 normalization
                features.append(feature)
                valid_paths.append(img_path)
            except Exception as e:
                print(f"Error while processing image {img_path}: {e}")
        if not features:
            raise ValueError("No image features were successfully extracted")
        return np.array(features), valid_paths
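    # Per-image pipeline sketch (illustrative, not called anywhere;
    # "extractor" stands for self.feature_extractor, 224 matches img_height/img_width):
    #
    #   img = image.load_img(path, target_size=(224, 224))
    #   x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))
    #   f = extractor.predict(x, verbose=0).flatten()
    #   f = f / np.linalg.norm(f)  # unit length, so dot product == cosine similarity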
    def test_model_functionality(self):
        """Check that the model works as expected."""
        print("Testing model functionality...")
        try:
            # Create a test image (all zeros)
            test_image = np.zeros((self.img_height, self.img_width, 3), dtype=np.float32)
            test_image = np.expand_dims(test_image, axis=0)
            test_image = preprocess_input(test_image)
            # Extract features
            test_features = self.feature_extractor.predict(test_image, verbose=0)
            test_features = test_features.flatten()
            # Check whether the features are all zeros
            if np.all(test_features == 0):
                print("Warning: the model produced an all-zero feature vector!")
                print("Possible cause: the model was not initialized correctly or the weights were not loaded")
            else:
                print(f"Model test passed, feature norm: {np.linalg.norm(test_features):.6f}")
                print(f"First 5 feature values: {test_features[:5]}")
        except Exception as e:
            print(f"Model test failed: {e}")

    def load_enrolled_hashes(self):
        """Load the hashes of files that have already been enrolled."""
        hash_file = "db/enrolled_hashes.txt"
        if os.path.exists(hash_file):
            try:
                with open(hash_file, 'r') as f:
                    self.enrolled_file_hashes = set(line.strip() for line in f)
                print(f"Loaded {len(self.enrolled_file_hashes)} hashes of enrolled files")
            except Exception as e:
                print(f"Error while loading hash file: {e}")
                self.enrolled_file_hashes = set()

    def save_enrolled_hashes(self):
        """Save the hashes of files that have been enrolled."""
        hash_file = "db/enrolled_hashes.txt"
        try:
            with open(hash_file, 'w') as f:
                for file_hash in self.enrolled_file_hashes:
                    f.write(file_hash + '\n')
        except Exception as e:
            print(f"Error while saving hash file: {e}")

    def calculate_file_hash(self, file_path):
        """Compute the MD5 hash of a file."""
        hash_md5 = hashlib.md5()
        try:
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(4096), b""):
                    hash_md5.update(chunk)
            return hash_md5.hexdigest()
        except Exception as e:
            print(f"Error while hashing file {file_path}: {e}")
            return None
    def load_database(self):
        """Load the plant face database."""
        if os.path.exists(self.database_path):
            try:
                with open(self.database_path, 'r') as f:
                    content = f.read().strip()
                    if not content:
                        print("Database file is empty, creating a new empty database")
                        self.database = {}
                        return
                    self.database = json.loads(content)
                    if not self.validate_database_structure():
                        print("Database structure validation failed, attempting repair")
                        self.try_repair_database()
                        return
                    # Process each plant's data
                    for plant_id in list(self.database.keys()):
                        if not plant_id or plant_id.strip() == "":
                            print("Warning: found an empty plant ID, removed it")
                            del self.database[plant_id]
                            continue
                        try:
                            self.database[plant_id]["image_count"] = int(self.database[plant_id]["image_count"])
                            # Guard against an empty feature list before indexing
                            if self.database[plant_id].get("features"):
                                if isinstance(self.database[plant_id]["features"][0], str):
                                    self.database[plant_id]["features"] = [
                                        np.frombuffer(base64.b64decode(feature_str), dtype=np.float32).tolist()
                                        for feature_str in self.database[plant_id]["features"]
                                    ]
                                else:
                                    self.database[plant_id]["features"] = [
                                        [float(x) for x in feature]
                                        for feature in self.database[plant_id]["features"]
                                    ]
                        except Exception as e:
                            print(f"Error while processing data for plant {plant_id}: {e}")
                            del self.database[plant_id]
                    print(f"Database loaded, contains {len(self.database)} plants")
            except Exception as e:
                print(f"Error while loading database: {e}")
                self.try_repair_database()
        else:
            self.database = {}
            print("Created a new empty database")
        if self.database:
            self.save_database()

    def validate_database_structure(self):
        """Validate that the database structure is correct."""
        if not isinstance(self.database, dict):
            return False
        for plant_id, plant_data in self.database.items():
            if not isinstance(plant_data, dict):
                return False
            required_keys = ["features", "enrollment_date", "update_history", "image_count"]
            for key in required_keys:
                if key not in plant_data:
                    return False
        return True
    # A command to check the model files
    def check_model_files(self):
        """Check the status of the model files."""
        print("=== Model file check ===")
        # Check the main model file
        if os.path.exists(self.model_path):
            size = os.path.getsize(self.model_path)
            print(f"Main model file: {self.model_path} (size: {size} bytes)")
        else:
            print(f"Main model file: {self.model_path} (missing)")
        # Check other possible model files
        possible_files = [
            self.model_path.replace('.h5', '_architecture.json'),
            self.model_path.replace('.h5', '_weights.h5'),
        ]
        for file_path in possible_files:
            if os.path.exists(file_path):
                size = os.path.getsize(file_path)
                print(f"Auxiliary file: {file_path} (size: {size} bytes)")
        print("=== Model info ===")
        print(f"Number of layers: {len(self.feature_extractor.layers)}")
        print(f"Input shape: {self.feature_extractor.input_shape}")
        print(f"Output shape: {self.feature_extractor.output_shape}")
    def try_repair_database(self):
        """Try to repair a corrupted database."""
        try:
            # Back up the original file
            if os.path.exists(self.database_path):
                backup_path = self.database_path + ".backup"
                shutil.copy2(self.database_path, backup_path)
                print(f"Created database backup: {backup_path}")
            # Read the JSON file
            with open(self.database_path, 'r') as f:
                content = f.read()
            # Try to fix common JSON formatting errors:
            # 1. Remove trailing commas
            content = content.replace(',\n}', '\n}').replace(',\n]', '\n]')
            # 2. Try to parse the repaired content
            self.database = json.loads(content)
            print("Database repaired successfully")
            # Save the repaired database
            self.save_database()
        except Exception as e:
            print(f"Database repair failed: {e}")
            print("A new empty database will be created")
            self.database = {}
            # Do not save immediately; wait until there is valid data

    def delete_plant(self, plant_id):
        """Delete the specified plant entry."""
        if plant_id not in self.database:
            return False, f"Plant {plant_id} does not exist"
        # Confirm the deletion
        confirm = input(f"Really delete plant {plant_id}? This cannot be undone! (y/n): ").strip().lower()
        if confirm != 'y':
            return False, "Deletion cancelled"
        # Remove from the database
        del self.database[plant_id]
        # Save the database
        self.save_database()
        return True, f"Plant {plant_id} has been deleted"
    def export_database_info(self):
        """Export database info to a text file."""
        try:
            export_path = "db/database_info.txt"
            with open(export_path, 'w') as f:
                f.write("=== Plant Face Recognition Database Info ===\n\n")
                f.write(f"Export time: {datetime.now().isoformat()}\n")
                f.write(f"Number of plants: {len(self.database)}\n\n")
                # Add per-species group counts
                species_groups = {}
                for plant_id in self.database.keys():
                    if '#' in plant_id:
                        species = plant_id.split('#')[0]
                    else:
                        species = plant_id
                    if species not in species_groups:
                        species_groups[species] = 0
                    species_groups[species] += 1
                f.write("=== Groups per species ===\n")
                for species, count in species_groups.items():
                    f.write(f"{species}: {count} group(s)\n")
                f.write("\n")
                for plant_id, data in self.database.items():
                    f.write(f"Plant ID: {plant_id}\n")
                    f.write(f"  Enrollment date: {data['enrollment_date']}\n")
                    f.write(f"  Number of images: {data['image_count']}\n")
                    f.write(f"  Number of feature vectors: {len(data['features'])}\n")
                    f.write(f"  Update history: {len(data['update_history'])} entries\n\n")
            print(f"Database info exported to {export_path}")
            return True
        except Exception as e:
            print(f"Error while exporting database info: {e}")
            return False
    def save_database(self):
        """Save the plant face database."""
        try:
            # Make sure all values are native Python types
            serializable_database = {}
            for plant_id, plant_data in self.database.items():
                # Key fix: store feature vectors Base64-encoded to avoid precision loss
                encoded_features = [
                    base64.b64encode(np.array(feature, dtype=np.float32).tobytes()).decode('utf-8')
                    for feature in plant_data["features"]
                ]
                serializable_database[plant_id] = {
                    "features": encoded_features,
                    "enrollment_date": plant_data["enrollment_date"],
                    "update_history": plant_data["update_history"],
                    "image_count": int(plant_data["image_count"])  # convert to Python int
                }
            with open(self.database_path, 'w') as f:
                json.dump(serializable_database, f, indent=4)
            print(f"Database saved to {self.database_path}")
        except Exception as e:
            print(f"Error while saving database: {e}")
    def load_stats(self):
        """Load recognition statistics."""
        if os.path.exists(self.stats_path):
            try:
                self.stats = pd.read_csv(self.stats_path, index_col=0)
                # Make sure the expected columns exist
                for col in ['true_id', 'correct', 'database_size', 'plant_image_count', 'species_group_count']:
                    if col not in self.stats.columns:
                        self.stats[col] = None
                print(f"Loaded statistics with {len(self.stats)} records")
            except Exception as e:
                print(f"Error while loading statistics: {e}")
                self.stats = pd.DataFrame(columns=[
                    'plant_id', 'recognition_time', 'num_images',
                    'recognized_id', 'confidence', 'admin_feedback',
                    'true_id', 'correct', 'database_size', 'plant_image_count', 'species_group_count'
                ])
        else:
            self.stats = pd.DataFrame(columns=[
                'plant_id', 'recognition_time', 'num_images',
                'recognized_id', 'confidence', 'admin_feedback',
                'true_id', 'correct', 'database_size', 'plant_image_count', 'species_group_count'
            ])
            print("Created a new statistics table")

    def save_stats(self):
        """Save recognition statistics."""
        try:
            # Make sure all values are native Python types
            self.stats = self.stats.applymap(
                lambda x: float(x) if isinstance(x, (np.integer, np.floating)) else x
            )
            self.stats.to_csv(self.stats_path)
            print(f"Statistics saved to {self.stats_path}")
        except Exception as e:
            print(f"Error while saving statistics: {e}")
    def enroll_plant(self, plant_id, image_paths):
        """Enroll a new plant face."""
        if len(image_paths) < 1:
            return False, "At least 1 image is required for enrollment"
        try:
            # Extract features (checking for duplicates)
            features, valid_paths = self.extract_features(image_paths, check_duplicates=True)
            if len(valid_paths) < 1:
                return False, "No valid images available for enrollment"
            # Compute the mean feature vector
            avg_feature = np.mean(features, axis=0)
            # Store in the database -- fixes a key bug
            if plant_id not in self.database:
                self.database[plant_id] = {
                    "features": [],
                    "enrollment_date": datetime.now().isoformat(),
                    "update_history": [],
                    "image_count": 0
                }
            # Correctly append the new feature (instead of overwriting)
            self.database[plant_id]["features"].append(avg_feature.tolist())
            self.database[plant_id]["image_count"] += len(valid_paths)
            # Update the history
            update_type = "enrollment" if len(self.database[plant_id]["features"]) == 1 else "update"
            self.database[plant_id]["update_history"].append({
                "date": datetime.now().isoformat(),
                "num_images": int(len(valid_paths)),
                "type": update_type
            })
            # Record the hashes of the enrolled files
            for img_path in valid_paths:
                file_hash = self.calculate_file_hash(img_path)
                if file_hash:
                    self.enrolled_file_hashes.add(file_hash)
            # Save the enrolled file hashes
            self.save_enrolled_hashes()
            self.save_database()
            return True, f"Plant {plant_id} enrolled successfully using {len(valid_paths)} image(s)"
        except Exception as e:
            return False, f"Error during enrollment: {e}"
    # recognize_plant as modified in plantFacing_1_09.py
    def recognize_plant(self, image_paths, threshold=0.7):
        """Recognize a plant and return the full similarity distribution."""
        if len(image_paths) < 1:
            return None, "At least 1 image is required for recognition", 0.0, {}
        try:
            # First print the current database state
            print(f"The database currently contains {len(self.database)} plants:")
            for plant_id, data in self.database.items():
                print(f"  {plant_id}: {len(data['features'])} feature vector(s), {data['image_count']} image(s)")
            # Extract features (no duplicate check, since recognition may reuse enrolled images)
            features, valid_paths = self.extract_features(image_paths, check_duplicates=False)
            if len(valid_paths) < 1:
                return None, "No valid images available for recognition", 0.0, {}
            # Keep the image paths for later use
            self.last_recognition_images = valid_paths.copy()
            avg_feature = np.mean(features, axis=0)
            # Find the most similar plant in the database
            best_match = None
            best_similarity = 0
            similarity_scores = {}  # all similarity scores
            similarity_distribution = {}  # new: the full similarity distribution
            for plant_id, plant_data in self.database.items():
                similarities = []
                for stored_feature in plant_data["features"]:
                    # Make sure the stored feature is a numpy array
                    stored_feature_array = np.array(stored_feature, dtype=np.float32)
                    # Compute the cosine similarity
                    dot_product = np.dot(avg_feature, stored_feature_array)
                    norm_a = np.linalg.norm(avg_feature)
                    norm_b = np.linalg.norm(stored_feature_array)
                    # Guard against division by zero
                    if norm_a < 1e-10 or norm_b < 1e-10:
                        similarity = 0.0
                    else:
                        similarity = dot_product / (norm_a * norm_b)
                    similarities.append(similarity)
                # Take the highest similarity for this plant
                plant_similarity = max(similarities) if similarities else 0
                similarity_scores[plant_id] = plant_similarity
                similarity_distribution[plant_id] = plant_similarity  # store each plant's similarity
                if plant_similarity > best_similarity:
                    best_similarity = plant_similarity
                    best_match = plant_id
            # Print all similarity scores for debugging
            print("Similarity scores:")
            for plant_id, score in similarity_scores.items():
                print(f"  {plant_id}: {score:.4f}")
            # Count the groups of the same species
            species_group_count = 0
            if best_match and best_match in self.database:
                species_group_count = self.get_species_group_count(best_match)
            # Record the recognition statistics
            recognition_record = {
                'plant_id': 'unknown',
                'recognition_time': datetime.now().isoformat(),
                'num_images': int(len(valid_paths)),
                'recognized_id': best_match if best_similarity >= threshold else 'unknown',
                'confidence': float(best_similarity),
                'admin_feedback': 'pending',
                'true_id': None,
                'correct': None,
                'database_size': sum(plant_data["image_count"] for plant_data in self.database.values()),
                'plant_image_count': self.database[best_match]["image_count"] if best_match in self.database else 0,
                'species_group_count': int(species_group_count)  # new: groups of the same species
            }
            # Append to the statistics DataFrame
            self.stats = pd.concat([
                self.stats,
                pd.DataFrame([recognition_record])
            ], ignore_index=True)
            self.last_recognition_index = len(self.stats) - 1  # index of the latest recognition
            self.last_recognition_result = best_match if best_similarity >= threshold else None
            self.save_stats()
            # Return the full similarity distribution
            if best_similarity >= threshold:
                return best_match, f"Recognized as plant {best_match} (confidence: {best_similarity:.3f}, species groups: {species_group_count})", best_similarity, similarity_distribution
            else:
                return None, f"No matching plant found (highest similarity: {best_similarity:.3f})", best_similarity, similarity_distribution
        except Exception as e:
            return None, f"Error during recognition: {e}", 0.0, {}
    def confirm_recognition(self, plant_id=None, is_correct=True, enroll_images=False):
        """Let the administrator confirm the recognition result (for the most recent recognition)."""
        if self.last_recognition_index is None:
            return False, "There is no recent recognition record"
        try:
            record = self.stats.iloc[self.last_recognition_index]
            # If the recognition was correct and no plant ID was given, use the recognized ID
            if is_correct and plant_id is None and self.last_recognition_result:
                plant_id = self.last_recognition_result
            # Capture the current database state (for statistics)
            current_db_size = sum(plant_data["image_count"] for plant_data in self.database.values())
            current_plant_count = (self.database[plant_id]["image_count"]
                                   if plant_id and plant_id in self.database else 0)
            # Count the groups of the same species
            species_group_count = 0
            if plant_id and plant_id in self.database:
                species_group_count = self.get_species_group_count(plant_id)
            # If the images should be enrolled, update the statistics first, then enroll
            if enroll_images and plant_id is not None and self.last_recognition_images:
                # Update the statistics first (using the pre-enrollment database state)
                if is_correct:
                    self.stats.at[self.last_recognition_index, 'admin_feedback'] = f"correct_{plant_id}"
                    self.stats.at[self.last_recognition_index, 'true_id'] = plant_id
                    self.stats.at[self.last_recognition_index, 'correct'] = True
                else:
                    self.stats.at[self.last_recognition_index, 'admin_feedback'] = "incorrect"
                    self.stats.at[self.last_recognition_index, 'true_id'] = plant_id
                    self.stats.at[self.last_recognition_index, 'correct'] = False
                self.stats.at[self.last_recognition_index, 'database_size'] = int(current_db_size)
                self.stats.at[self.last_recognition_index, 'plant_image_count'] = int(current_plant_count)
                self.stats.at[self.last_recognition_index, 'species_group_count'] = int(species_group_count)
                # Save the statistics
                self.save_stats()
                # Then add the images to that plant's feature bank
                success, message = self.enroll_plant(plant_id, self.last_recognition_images)
                if success:
                    print(f"Images added to the feature bank of plant {plant_id}: {message}")
                else:
                    print(f"Failed to add images to the feature bank: {message}")
            else:
                # No enrollment: just update the statistics
                if is_correct:
                    self.stats.at[self.last_recognition_index, 'admin_feedback'] = f"correct_{plant_id}"
                    self.stats.at[self.last_recognition_index, 'true_id'] = plant_id
                    self.stats.at[self.last_recognition_index, 'correct'] = True
                else:
                    self.stats.at[self.last_recognition_index, 'admin_feedback'] = "incorrect"
                    self.stats.at[self.last_recognition_index, 'true_id'] = plant_id
                    self.stats.at[self.last_recognition_index, 'correct'] = False
                self.stats.at[self.last_recognition_index, 'database_size'] = int(current_db_size)
                self.stats.at[self.last_recognition_index, 'plant_image_count'] = int(current_plant_count)
                self.stats.at[self.last_recognition_index, 'species_group_count'] = int(species_group_count)
                # Save the statistics
                self.save_stats()
            # Update the plant's history (without re-initializing its data)
            if is_correct and plant_id in self.database:
                self.database[plant_id]["update_history"].append({
                    "date": datetime.now().isoformat(),
                    "num_images": int(record['num_images']),
                    "type": "confirmation_update"
                })
                self.save_database()
            return True, "Feedback recorded"
        except Exception as e:
            return False, f"Error during confirmation: {e}"
    def list_plants(self):
        """List all enrolled plants."""
        if not self.database:
            print("The database contains no plants yet")
            return
        print("\n=== Enrolled plants ===")
        # Count groups per species while listing each plant
        species_stats = {}
        for plant_id, data in self.database.items():
            if '#' in plant_id:
                species = plant_id.split('#')[0]
            else:
                species = plant_id
            if species not in species_stats:
                species_stats[species] = 0
            species_stats[species] += 1
            print(f"ID: {plant_id}, enrolled: {data['enrollment_date']}, images: {data['image_count']}")
        # Show the per-species group counts
        print("\n=== Groups per species ===")
        for species, count in species_stats.items():
            print(f"{species}: {count} group(s)")
    def show_accuracy_stats(self):
        """Show accuracy statistics."""
        if len(self.stats) == 0:
            print("No statistics yet")
            return
        # Filter to confirmed records
        confirmed_stats = self.stats[self.stats['correct'].notna()]
        if len(confirmed_stats) == 0:
            print("No confirmed recognition records yet")
            return
        # Overall accuracy
        total = len(confirmed_stats)
        correct = len(confirmed_stats[confirmed_stats['correct'] == True])
        accuracy = correct / total * 100 if total > 0 else 0
        print("\n=== Recognition accuracy ===")
        print(f"Total recognitions: {total}")
        print(f"Correct recognitions: {correct}")
        print(f"Overall accuracy: {accuracy:.2f}%")
        # Accuracy grouped by number of images
        if 'num_images' in confirmed_stats.columns:
            print("\nAccuracy by number of images:")
            for num_images in sorted(confirmed_stats['num_images'].unique()):
                group = confirmed_stats[confirmed_stats['num_images'] == num_images]
                group_total = len(group)
                group_correct = len(group[group['correct'] == True])
                group_accuracy = group_correct / group_total * 100 if group_total > 0 else 0
                print(f"  {num_images} image(s): {group_accuracy:.2f}% ({group_correct}/{group_total})")
        # Accuracy grouped by database size
        if 'database_size' in confirmed_stats.columns:
            print("\nAccuracy by database size:")
            db_sizes = sorted(confirmed_stats['database_size'].unique())
            for db_size in db_sizes:
                group = confirmed_stats[confirmed_stats['database_size'] == db_size]
                group_total = len(group)
                group_correct = len(group[group['correct'] == True])
                group_accuracy = group_correct / group_total * 100 if group_total > 0 else 0
                print(f"  Database size {db_size}: {group_accuracy:.2f}% ({group_correct}/{group_total})")
        # Accuracy grouped by the plant's image count
        if 'plant_image_count' in confirmed_stats.columns:
            print("\nAccuracy by plant image count:")
            plant_counts = sorted(confirmed_stats['plant_image_count'].unique())
            for count in plant_counts:
                group = confirmed_stats[confirmed_stats['plant_image_count'] == count]
                group_total = len(group)
                group_correct = len(group[group['correct'] == True])
                group_accuracy = group_correct / group_total * 100 if group_total > 0 else 0
                print(f"  Plant image count {count}: {group_accuracy:.2f}% ({group_correct}/{group_total})")
        # New: accuracy grouped by the number of same-species groups
        if 'species_group_count' in confirmed_stats.columns:
            print("\nAccuracy by number of same-species groups:")
            group_counts = sorted(confirmed_stats['species_group_count'].unique())
            for count in group_counts:
                group = confirmed_stats[confirmed_stats['species_group_count'] == count]
                group_total = len(group)
                group_correct = len(group[group['correct'] == True])
                group_accuracy = group_correct / group_total * 100 if group_total > 0 else 0
                print(f"  Species groups {count}: {group_accuracy:.2f}% ({group_correct}/{group_total})")
    def clear_temp_folder(self):
        """Empty the temporary folder."""
        temp_dir = "imgTemp"
        if not os.path.exists(temp_dir):
            print(f"Temporary folder {temp_dir} does not exist")
            return False
        try:
            # List all files in the directory
            files = [f for f in os.listdir(temp_dir) if os.path.isfile(os.path.join(temp_dir, f))]
            if not files:
                print("Temporary folder is empty")
                return True
            # Delete all files
            for file in files:
                file_path = os.path.join(temp_dir, file)
                os.remove(file_path)
            print(f"Temporary folder emptied, deleted {len(files)} file(s)")
            return True
        except Exception as e:
            print(f"Error while emptying temporary folder: {e}")
            return False
    def validate_database(self):
        """Check the integrity of the database."""
        print("Validating database integrity...")
        for plant_id, data in self.database.items():
            feature_count = len(data["features"])
            print(f"Plant {plant_id}: {feature_count} feature vector(s), {data['image_count']} image(s)")
            # Check that all feature vectors have the same dimension
            if feature_count > 0:
                first_dim = len(data["features"][0])
                for i, feature in enumerate(data["features"]):
                    if len(feature) != first_dim:
                        print(f"  Warning: feature vector {i} has an inconsistent dimension")
        # Check that the feature vectors were loaded correctly
        print("\nChecking feature vector format:")
        for plant_id, data in self.database.items():
            if data["features"]:
                sample_feature = np.array(data["features"][0])
                print(f"Plant {plant_id}: feature shape {sample_feature.shape}, dtype {sample_feature.dtype}")

    def reset_database(self):
        """Reset the database (use with caution)."""
        confirm = input("Really reset the database? All data will be lost! (y/n): ")
        if confirm.lower() == 'y':
            self.database = {}
            self.save_database()
            self.enrolled_file_hashes = set()
            self.save_enrolled_hashes()
            print("Database has been reset")
    def test_feature_consistency(self, image_path):
        """Test the consistency of feature extraction."""
        if not os.path.exists(image_path):
            print(f"Image does not exist: {image_path}")
            return
        # Extract features several times and check consistency
        features = []
        for i in range(3):  # fewer repetitions to keep the test fast
            try:
                feature, _ = self.extract_features([image_path], check_duplicates=False)
                if len(feature) > 0:
                    features.append(feature[0])
                else:
                    print(f"Feature extraction attempt {i + 1} failed: no features extracted")
            except Exception as e:
                print(f"Feature extraction attempt {i + 1} failed: {e}")
                continue
        if len(features) < 2:
            print("Feature extraction failed, cannot run the consistency test")
            return
        # Compute pairwise similarities between the features
        similarities = []
        for i in range(len(features)):
            for j in range(i + 1, len(features)):
                dot_product = np.dot(features[i], features[j])
                norm_a = np.linalg.norm(features[i])
                norm_b = np.linalg.norm(features[j])
                if norm_a < 1e-10 or norm_b < 1e-10:
                    similarity = 0.0
                else:
                    similarity = dot_product / (norm_a * norm_b)
                similarities.append(similarity)
        avg_similarity = np.mean(similarities)
        print(f"Feature extraction consistency test: mean similarity = {avg_similarity:.6f}")
        # Print the first few values of each feature vector
        print("First 5 feature values:")
        for i, feature in enumerate(features):
            print(f"  Attempt {i + 1}: {feature[:5]}")
        return avg_similarity
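
# Usage sketch for importing the class, e.g. for batch tests (illustrative,
# not part of the original gist; "plant_face" is a hypothetical module name):
#
#   from plant_face import PlantFaceRecognitionSystem
#   system = PlantFaceRecognitionSystem()
#   ok, msg = system.enroll_plant("rose#1", ["imgTemp/rose_01.jpg"])
#   plant_id, msg, confidence, scores = system.recognize_plant(["imgTemp/query.jpg"])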
# Interactive command-line interface
def main():
    # Make sure the temporary and database directories exist
    os.makedirs("imgTemp", exist_ok=True)
    os.makedirs("db", exist_ok=True)
    # Ask for the model path
    model_path = input("Enter the model file path (press Enter for the default): ").strip()
    if model_path == "":
        model_path = "db/plant_model.h5"
    system = PlantFaceRecognitionSystem(model_path=model_path)
    print("\n=== Plant Face Recognition System ===")
    print("Commands: enroll, recognize, confirm, stats, list, accuracy, clear, validate, reset, test, delete, export, exit")
    print("Tip: put the images in the imgTemp directory")
    while True:
        command = input("\nEnter a command: ").strip().lower()
        if command == "exit":
            print("Saving data and exiting...")
            system.save_database()
            system.save_stats()
            system.save_enrolled_hashes()
            break
        elif command == "enroll" or command == "en":
            plant_id = input("Enter the plant ID: ").strip()
            image_dir = "imgTemp"
            if not os.path.exists(image_dir):
                print("Error: directory does not exist")
                continue
            # Collect all images in the directory
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
            image_paths = [
                os.path.join(image_dir, f) for f in os.listdir(image_dir)
                if any(f.lower().endswith(ext) for ext in image_extensions)
            ]
            if not image_paths:
                print("Error: no image files found in the directory")
                continue
            success, message = system.enroll_plant(plant_id, image_paths)
            print(message)
            # Ask whether to delete the processed images
            if success and input("Delete the processed images? (y/n): ").strip().lower() == 'y':
                for img_path in image_paths:
                    try:
                        os.remove(img_path)
                    except OSError:
                        pass
                print("Image files deleted")
        elif command == "recognize" or command == "rec":
            image_dir = "imgTemp"
            if not os.path.exists(image_dir):
                print("Error: directory does not exist")
                continue
            # Collect all images in the directory
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
            image_paths = [
                os.path.join(image_dir, f) for f in os.listdir(image_dir)
                if any(f.lower().endswith(ext) for ext in image_extensions)
            ]
            if not image_paths:
                print("Error: no image files found in the directory")
                continue
            # All similarities are returned but not printed in this version
            plant_id, message, confidence, all_scores = system.recognize_plant(image_paths)
            print(message)
        elif command == "confirm" or command == "con":
            if system.last_recognition_index is None:
                print("Error: no recent recognition record, run a recognition first")
                continue
            # Show the most recent recognition record
            last_record = system.stats.iloc[system.last_recognition_index]
            print(f"Most recent recognition record (index {system.last_recognition_index}):")
            print(f"  Recognized as: {last_record['recognized_id']}")
            print(f"  Confidence: {last_record['confidence']:.3f}")
            print(f"  Number of images: {last_record['num_images']}")
            print(f"  Same-species groups: {last_record['species_group_count']}")
            feedback = input("Is the recognition correct? (y/n): ").strip().lower()
            if feedback == 'y':
                # Ask whether to add these images to the database
                enroll_choice = input("Add these images to the database? (y/n): ").strip().lower()
                enroll_images = (enroll_choice == 'y')
                # Use the recognition result as the plant ID
                plant_id = system.last_recognition_result
                print(f"Using the recognition result as the plant ID: {plant_id}")
                success, message = system.confirm_recognition(plant_id, True, enroll_images)
            else:
                # Ask for the correct plant ID
                plant_id = input("Enter the correct plant ID: ")
                # Ask whether to add these images to the database
                enroll_choice = input("Add these images to the database? (y/n): ").strip().lower()
                enroll_images = (enroll_choice == 'y')
                success, message = system.confirm_recognition(plant_id, False, enroll_images)
            print(message)
        elif command == "stats":
            print("\n=== Recognition statistics ===")
            print(system.stats)
        elif command == "check_model":
            system.check_model_files()
        elif command == "list" or command == "ls":
            system.list_plants()
        elif command == "accuracy":
            system.show_accuracy_stats()
        elif command == "clear" or command == "cl":
            success = system.clear_temp_folder()
            if success:
                print("Temporary folder has been emptied")
        elif command == "validate":
            system.validate_database()
        elif command == "reset":
            system.reset_database()
        elif command == "test":
            # Try the first image in the imgTemp directory
            image_dir = "imgTemp"
            if os.path.exists(image_dir):
                image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
                image_files = [
                    f for f in os.listdir(image_dir)
                    if any(f.lower().endswith(ext) for ext in image_extensions)
                ]
                if image_files:
                    image_path = os.path.join(image_dir, image_files[0])
                    print(f"Testing with image {image_files[0]}")
                    system.test_feature_consistency(image_path)
                else:
                    print("No image files in the imgTemp directory")
            else:
                print("The imgTemp directory does not exist")
        elif command == "delete":
            plant_id = input("Enter the ID of the plant to delete: ").strip()
            success, message = system.delete_plant(plant_id)
            print(message)
        elif command == "export":
            success = system.export_database_info()
            if success:
                print("Database info exported successfully")
        else:
            print("Unknown command; use: enroll, recognize, confirm, stats, list, accuracy, clear, validate, reset, test, delete, export, exit")


if __name__ == "__main__":
    main()