# Code for recognizing ten classes of photos. The code for recognizing two,
# three, and five classes follows the same structure, differing only in the
# number of classes.
# 1. Automatic data split: build train/val/test from the "leaves" folder
# 2. Standalone class for generating per-class average images
import os
import sys
import warnings
import random
import shutil
# Environment variables to reduce log noise
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Guard the core imports and fail with a clear message
try:
    import numpy as np
except ImportError:
    print("Error: NumPy not installed. Please install NumPy 1.19.5")
    sys.exit(1)
try:
    import tensorflow as tf
except ImportError:
    print("Error: TensorFlow not installed. Please install TensorFlow 2.5.0")
    sys.exit(1)
# Probe the remaining third-party libraries and give a friendly message first
optional_libs = ['pandas', 'umap', 'PIL']
missing_libs = []
for lib in optional_libs:
    try:
        __import__(lib)
    except ImportError:
        missing_libs.append(lib)
if missing_libs:
    print(f"Warning: Missing optional libraries: {', '.join(missing_libs)}")
    print("Some advanced features will be disabled.")
    DISABLE_ADVANCED_FEATURES = True
else:
    DISABLE_ADVANCED_FEATURES = False
import pandas as pd
from PIL import Image
import umap
from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
# Silence specific warning categories
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_curve, auc
from itertools import cycle
import glob
# Work around a grpcio compatibility warning
try:
    from grpc._cython import cygrpc
except ImportError:
    pass  # safe to ignore
from average_image_generator import ClassAverageImageGenerator
# Use English sans-serif fonts in all figures
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['savefig.format'] = 'pdf'  # save figures as vector graphics
plt.rcParams['pdf.fonttype'] = 42  # keep PDF text editable
plt.rcParams['ps.fonttype'] = 42   # keep PS text editable
# Disable the GPU (drop these two lines to train on GPU)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
tf.config.set_visible_devices([], 'GPU')
# Verify the device configuration
print("Available devices:", tf.config.list_physical_devices())
print("GPU available:", tf.config.list_physical_devices('GPU'))
# ----------------------
# 0. Set the working directory
# ----------------------
os.chdir(".../working/dir")  # replace with your own working directory
# ----------------------
# New: automatic data-splitting class
# ----------------------
class DataSplitter:
    def __init__(self, source_dir="leaves", output_base_dir="split_data",
                 split_ratios=[0.6, 0.2, 0.2], random_seed=42, clean_output=True):
        """
        Initialize the data splitter.
        Args:
            source_dir: source data directory containing one subfolder per class
            output_base_dir: output directory
            split_ratios: split proportions [train, val, test]
            random_seed: random seed
            clean_output: whether to wipe the output directory before splitting
        """
        self.source_dir = source_dir
        self.output_base_dir = output_base_dir
        self.split_ratios = split_ratios
        self.random_seed = random_seed
        self.clean_output = clean_output
        # Validate the split proportions
        if abs(sum(split_ratios) - 1.0) > 0.01:
            raise ValueError(f"Split ratios must sum to 1.0, got {sum(split_ratios)}")
        # Output directories
        self.train_dir = os.path.join(output_base_dir, "train")
        self.val_dir = os.path.join(output_base_dir, "val")
        self.test_dir = os.path.join(output_base_dir, "test")
        # Seed both RNGs for reproducible splits
        random.seed(random_seed)
        np.random.seed(random_seed)
    def split_data(self):
        """Perform the data split."""
        print("Starting automatic data split...")
        print(f"Source directory: {self.source_dir}")
        print(f"Output directory: {self.output_base_dir}")
        print(f"Split ratios: Train {self.split_ratios[0] * 100}%, "
              f"Val {self.split_ratios[1] * 100}%, Test {self.split_ratios[2] * 100}%")
        # Wipe or create the output directories
        self._prepare_output_directories()
        # Collect all class folders
        class_folders = [f for f in os.listdir(self.source_dir)
                         if os.path.isdir(os.path.join(self.source_dir, f))]
        class_folders.sort()
        print(f"Found {len(class_folders)} classes: {class_folders}")
        total_files = 0
        detailed_stats = {}
        for class_folder in class_folders:
            class_path = os.path.join(self.source_dir, class_folder)
            # Collect all image files for this class
            image_files = self._get_image_files(class_path)
            if not image_files:
                print(f"Warning: no image files found for class {class_folder}")
                continue
            # Shuffle the file order
            random.shuffle(image_files)
            # Compute the split points
            n_total = len(image_files)
            n_train = int(n_total * self.split_ratios[0])
            n_val = int(n_total * self.split_ratios[1])
            # Slice into the three subsets
            train_files = image_files[:n_train]
            val_files = image_files[n_train:n_train + n_val]
            test_files = image_files[n_train + n_val:]
            # Copy the files into the corresponding directories
            self._copy_files(train_files, class_folder, self.train_dir)
            self._copy_files(val_files, class_folder, self.val_dir)
            self._copy_files(test_files, class_folder, self.test_dir)
            total_files += n_total
            # Record detailed statistics
            detailed_stats[class_folder] = {
                'total': n_total,
                'train': len(train_files),
                'val': len(val_files),
                'test': len(test_files)
            }
            print(f"Class {class_folder}: {n_total} images -> "
                  f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")
        print(f"\nData split complete! Processed {total_files} images in total")
        # Verify the split results
        self._verify_split_results(detailed_stats)
        return self.train_dir, self.val_dir, self.test_dir
    def _prepare_output_directories(self):
        """Prepare the output directories (wipe and/or create them)."""
        if self.clean_output and os.path.exists(self.output_base_dir):
            print(f"Wiping output directory: {self.output_base_dir}")
            shutil.rmtree(self.output_base_dir)
        # Create the output directories
        os.makedirs(self.train_dir, exist_ok=True)
        os.makedirs(self.val_dir, exist_ok=True)
        os.makedirs(self.test_dir, exist_ok=True)
        print(f"Created output directory: {self.output_base_dir}")
    def _get_image_files(self, class_path):
        """Collect all image files in a directory."""
        image_files = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
            image_files.extend(glob.glob(os.path.join(class_path, ext)))
            image_files.extend(glob.glob(os.path.join(class_path, ext.upper())))
        # De-duplicate (the lower/upper-case patterns match the same files on
        # case-insensitive filesystems) and sort so that the seeded shuffle in
        # split_data is reproducible across platforms.
        return sorted(set(image_files))
    def _copy_files(self, file_list, class_name, target_dir):
        """Copy files into the target directory."""
        class_target_dir = os.path.join(target_dir, class_name)
        os.makedirs(class_target_dir, exist_ok=True)
        for file_path in file_list:
            filename = os.path.basename(file_path)
            target_path = os.path.join(class_target_dir, filename)
            shutil.copy2(file_path, target_path)
    def _verify_split_results(self, detailed_stats):
        """Verify the split results."""
        print("\n=== Split verification ===")
        # Compare the file count in each split directory with the expected count
        for split_name, split_dir in [('train', self.train_dir),
                                      ('val', self.val_dir),
                                      ('test', self.test_dir)]:
            actual_count = 0
            for class_folder in os.listdir(split_dir):
                class_path = os.path.join(split_dir, class_folder)
                if os.path.isdir(class_path):
                    files = [f for f in os.listdir(class_path)
                             if os.path.isfile(os.path.join(class_path, f))]
                    actual_count += len(files)
            expected_count = sum(stats[split_name] for stats in detailed_stats.values())
            print(f"{split_name}: expected {expected_count}, actual {actual_count}")
            if actual_count != expected_count:
                print(f"Warning: file count mismatch in the {split_name} directory!")
# ----------------------
# 1. Automatic data split
# ----------------------
# Split configuration
split_ratios = [0.6, 0.2, 0.2]  # train: 60%, val: 20%, test: 20%
source_dir = "leaves"  # raw data directory
output_base_dir = "split_data"  # output directory for the split data
# Run the split
splitter = DataSplitter(
    source_dir=source_dir,
    output_base_dir=output_base_dir,
    split_ratios=split_ratios,
    random_seed=42,
    clean_output=True
)
train_dir, val_dir, test_dir = splitter.split_data()
# ----------------------
# 2. Set data paths (using the automatic split results)
# ----------------------
model_save_path = "leaf_classifier_10class.h5"  # where to save the model
output_dir = "analysis_results"  # where to write the analysis results
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
# Check that the split directories exist
if not os.path.exists(train_dir):
    raise FileNotFoundError(f"Training directory {train_dir} not found!")
if not os.path.exists(val_dir):
    raise FileNotFoundError(f"Validation directory {val_dir} not found!")
if not os.path.exists(test_dir):
    raise FileNotFoundError(f"Test directory {test_dir} not found!")
# ----------------------
# 3. Create data generators - adapted for 10-class classification
# ----------------------
IMG_HEIGHT = 224  # image height
IMG_WIDTH = 224   # image width
BATCH_SIZE = 32
NUM_CLASSES = 10  # number of classes
# Augmentation configuration (moderately strengthened for the larger class count)
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=8,
    width_shift_range=0.13,
    height_shift_range=0.13,
    zoom_range=0.1,
    channel_shift_range=15,
    shear_range=0.1,
    fill_mode='reflect',
    cval=0.0,
    vertical_flip=False,
    horizontal_flip=False,
    brightness_range=[0.9, 1.1]
)
val_datagen = ImageDataGenerator(rescale=1.0 / 255)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)
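# Illustrative helper, not called by default: preview one augmented batch to
# sanity-check the augmentation settings above. The function name and the
# 3x3 grid are assumptions, not part of the original pipeline.
def preview_augmented_batch(n=9):
    preview_gen = train_datagen.flow_from_directory(
        train_dir, target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=n, class_mode=None, shuffle=True)
    images = next(preview_gen)  # one batch of augmented, rescaled images
    plt.figure(figsize=(6, 6))
    for k in range(min(n, len(images))):
        plt.subplot(3, 3, k + 1)
        plt.imshow(images[k])
        plt.axis('off')
    plt.tight_layout()
    plt.show()
# preview_augmented_batch()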
# Main generators (used for training and evaluation)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True  # shuffling during training matters
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)
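# Note: the val/test generators use shuffle=False so that the outputs of
# model.predict(generator) stay aligned with generator.classes and
# generator.filepaths.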
# Dedicated analysis generators (fixed order, for post-processing).
# The training-set analysis generator uses the rescale-only pipeline rather
# than train_datagen, so the analysis runs on un-augmented images.
print("\nCreating dedicated analysis generators...")
analysis_train_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,  # key: keep the order fixed
    seed=42
)
analysis_val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
    seed=42
)
analysis_test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
    seed=42
)
# Make sure the analysis generators share the main generators' class mapping
assert analysis_train_generator.class_indices == train_generator.class_indices
assert analysis_val_generator.class_indices == val_generator.class_indices
assert analysis_test_generator.class_indices == test_generator.class_indices
print("Analysis generators created; class mappings are consistent")
# Class names and indices
class_names = list(train_generator.class_indices.keys())
class_indices = train_generator.class_indices
print(f"Class indices: {class_indices}")
print(f"Class names: {class_names}")
# ----------------------
# 4. Build the model - adapted for 10 classes (larger classification head)
# ----------------------
def build_model():
    # Pretrained backbone as a frozen feature extractor
    base_model = MobileNetV2(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False
    # Global average pooling to obtain a feature vector
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # Keep a separate feature-extraction model
    feature_model = Model(inputs=base_model.input, outputs=x)
    # Classification head
    y = Dense(1024, activation='relu')(x)
    y = Dropout(0.5)(y)
    y = Dense(512, activation='relu')(y)
    y = Dropout(0.3)(y)
    y = Dense(256, activation='relu')(y)
    y = Dense(128, activation='relu')(y)
    predictions = Dense(NUM_CLASSES, activation='softmax')(y)
    # Full model
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model, feature_model  # classifier and feature extractor
model, feature_model = build_model()
model.summary()  # print the model architecture
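# Illustrative sketch, not called by default: a second fine-tuning stage that
# unfreezes the top of the (currently frozen) backbone and recompiles with a
# lower learning rate. The layer count (30) and the rate are assumptions,
# not tuned values.
def enable_fine_tuning(model, n_trainable=30):
    model.trainable = True
    for layer in model.layers[:-n_trainable]:
        layer.trainable = False
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
# enable_fine_tuning(model)  # then call model.fit(...) again to fine-tune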
# ----------------------
# 5. Train the model
# ----------------------
# Callback configuration (patience raised for the harder task)
callbacks = [
    EarlyStopping(patience=10, monitor='val_loss', restore_best_weights=True),
    ModelCheckpoint(model_save_path, save_best_only=True)
]
# Start training
print("\nStarting training for 10-class classification...")
history = model.fit(
    train_generator,
    workers=32,
    epochs=512,  # generous cap; EarlyStopping ends training much sooner
    validation_data=val_generator,
    callbacks=callbacks
)
# ----------------------
# 6. Evaluate the model - English labels
# ----------------------
# Load the best checkpoint
best_model = load_model(model_save_path)
best_model.trainable = False  # make sure no weights are updated
# Evaluate on the validation set
loss, accuracy = best_model.evaluate(val_generator)
print(f"\nValidation accuracy: {accuracy * 100:.2f}%")
# Plot the training curves and save them as vector graphics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy (10 Classes)')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss (10 Classes)')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'training_history_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'training_history_10class.png'), dpi=300)
plt.show()
# Predictions and true labels for the whole validation set
print("\nGenerating predictions...")
val_preds = best_model.predict(val_generator)
val_pred_classes = np.argmax(val_preds, axis=1)  # predicted class indices
val_true_classes = val_generator.classes  # true class indices
# 1. Overall metrics
print("\n===== Overall Metrics =====")
print(f"Validation accuracy: {accuracy_score(val_true_classes, val_pred_classes) * 100:.2f}%")
# 2. Confusion matrix
cm = confusion_matrix(val_true_classes, val_pred_classes)
print("\nConfusion Matrix:")
print(cm)
# Visualize the confusion matrix (sized for the larger class count)
plt.figure(figsize=(14, 12))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix (10 Classes)', fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45, ha='right', fontsize=10)
plt.yticks(tick_marks, class_names, fontsize=10)
# Add the cell counts
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=9)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'confusion_matrix_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'confusion_matrix_10class.png'), dpi=300, bbox_inches='tight')
plt.show()
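# Illustrative addition, not part of the original analysis: a row-normalized
# view of the same matrix, often easier to read when class sizes differ.
# np.maximum guards against empty rows.
cm_norm = cm.astype(float) / np.maximum(cm.sum(axis=1, keepdims=True), 1)
print("Row-normalized confusion matrix (each row sums to ~1):")
print(np.round(cm_norm, 2))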
# 3. Classification report
print("\nClassification Report:")
report = classification_report(
    val_true_classes,
    val_pred_classes,
    target_names=class_names,
    output_dict=True  # dict output is easier to save
)
print(pd.DataFrame(report).transpose())
# Save the classification report (to_excel requires the openpyxl package)
report_df = pd.DataFrame(report).transpose()
report_df.to_csv(os.path.join(output_dir, 'classification_report.csv'))
report_df.to_excel(os.path.join(output_dir, 'classification_report.xlsx'))
# 4. Per-class metrics
print("\n===== Class-wise Performance =====")
class_accuracies = []
class_recalls = []
class_precisions = []
for i, class_name in enumerate(class_names):
    # Indices of the samples whose true label is this class
    # (named class_sample_indices to avoid shadowing class_indices above)
    class_sample_indices = np.where(val_true_classes == i)[0]
    # Per-class accuracy (restricted to true members of the class, this
    # coincides with recall)
    correct = np.sum(val_pred_classes[class_sample_indices] == i)
    total = len(class_sample_indices)
    acc = correct / total if total > 0 else 0
    class_accuracies.append(acc)
    # Recall
    recall = correct / total if total > 0 else 0
    class_recalls.append(recall)
    # Precision
    predicted_indices = np.where(val_pred_classes == i)[0]
    precision = correct / len(predicted_indices) if len(predicted_indices) > 0 else 0
    class_precisions.append(precision)
    print(f"Class '{class_name}':")
    print(f"  Accuracy: {acc * 100:.2f}% ({correct}/{total})")
    print(f"  Recall: {recall * 100:.2f}%")
    print(f"  Precision: {precision * 100:.2f}%")
# Visualize the per-class metrics (horizontal bars suit the larger class count)
plt.figure(figsize=(12, 10))
y_pos = np.arange(len(class_names))
# Three horizontal bar groups
plt.barh(y_pos - 0.25, class_accuracies, height=0.25, label='Accuracy', color='#1f77b4')
plt.barh(y_pos, class_recalls, height=0.25, label='Recall', color='#ff7f0e')
plt.barh(y_pos + 0.25, class_precisions, height=0.25, label='Precision', color='#2ca02c')
plt.xlabel('Score', fontsize=12)
plt.title('Class-wise Performance Metrics (10 Classes)', fontsize=16)
plt.yticks(y_pos, class_names, fontsize=10)
plt.xlim(0, 1.1)
plt.legend(loc='lower right', fontsize=10)
plt.grid(axis='x', linestyle='--', alpha=0.7)
# Add the value labels
for i, v in enumerate(class_accuracies):
    plt.text(v + 0.02, i - 0.25, f"{v:.2f}", color='black', fontsize=9)
for i, v in enumerate(class_recalls):
    plt.text(v + 0.02, i, f"{v:.2f}", color='black', fontsize=9)
for i, v in enumerate(class_precisions):
    plt.text(v + 0.02, i + 0.25, f"{v:.2f}", color='black', fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'class_performance_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'class_performance_10class.png'), dpi=300)
plt.show()
# 5. ROC curves (one-vs-rest for each of the 10 classes)
plt.figure(figsize=(12, 10))
colors = cycle(['blue', 'red', 'green', 'purple', 'orange',
                'cyan', 'magenta', 'lime', 'brown', 'pink'])
# ROC curve and AUC per class
for i, color in zip(range(NUM_CLASSES), colors):
    # Binary labels for the current class
    y_true_i = np.where(val_true_classes == i, 1, 0)
    y_score_i = val_preds[:, i]
    # Compute the ROC curve
    fpr, tpr, _ = roc_curve(y_true_i, y_score_i)
    roc_auc = auc(fpr, tpr)
    # Plot it
    plt.plot(fpr, tpr, color=color, lw=2,
             label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curves for 10-Class Classification', fontsize=16)
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'roc_curves_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'roc_curves_10class.png'), dpi=300)
plt.show()
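# Illustrative addition, not part of the original analysis: macro-average of
# the ten one-vs-rest AUCs as a single summary number.
per_class_aucs = []
for i in range(NUM_CLASSES):
    fpr_i, tpr_i, _ = roc_curve(np.where(val_true_classes == i, 1, 0), val_preds[:, i])
    per_class_aucs.append(auc(fpr_i, tpr_i))
print(f"Macro-averaged AUC over {NUM_CLASSES} classes: {np.mean(per_class_aucs):.3f}")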
# ----------------------
# 7. Final evaluation on the test set
# ----------------------
print("\n===== Test Set Evaluation =====")
test_loss, test_accuracy = best_model.evaluate(test_generator)
print(f"Test accuracy: {test_accuracy * 100:.2f}%")
# Predictions and true labels for the whole test set
test_preds = best_model.predict(test_generator)
test_pred_classes = np.argmax(test_preds, axis=1)
test_true_classes = test_generator.classes
# Test-set confusion matrix
test_cm = confusion_matrix(test_true_classes, test_pred_classes)
print("\nTest Confusion Matrix:")
print(test_cm)
# Visualize the test-set confusion matrix
plt.figure(figsize=(14, 12))
plt.imshow(test_cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Test Confusion Matrix (10 Classes)', fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45, ha='right', fontsize=10)
plt.yticks(tick_marks, class_names, fontsize=10)
thresh = test_cm.max() / 2.
for i in range(test_cm.shape[0]):
    for j in range(test_cm.shape[1]):
        plt.text(j, i, format(test_cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if test_cm[i, j] > thresh else "black",
                 fontsize=9)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'test_confusion_matrix_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'test_confusion_matrix_10class.png'), dpi=300, bbox_inches='tight')
plt.show()
# Test-set classification report
print("\nTest Classification Report:")
test_report = classification_report(
    test_true_classes,
    test_pred_classes,
    target_names=class_names,
    output_dict=True
)
print(pd.DataFrame(test_report).transpose())
# Save the test-set classification report
test_report_df = pd.DataFrame(test_report).transpose()
test_report_df.to_csv(os.path.join(output_dir, 'test_classification_report.csv'))
# Per-class performance on the test set
print("\n===== Test Class-wise Performance =====")
test_class_accuracies = []
for i, class_name in enumerate(class_names):
    class_sample_indices = np.where(test_true_classes == i)[0]
    correct = np.sum(test_pred_classes[class_sample_indices] == i)
    total = len(class_sample_indices)
    acc = correct / total if total > 0 else 0
    test_class_accuracies.append(acc)
    print(f"Class '{class_name}': Accuracy = {acc * 100:.2f}% ({correct}/{total})")
# Compare validation and test accuracy per class
plt.figure(figsize=(10, 6))
x = np.arange(len(class_names))
width = 0.25
# class_accuracies holds the per-class validation accuracies computed above
plt.bar(x - width, class_accuracies, width, label='Validation', color='#1f77b4')
plt.bar(x, test_class_accuracies, width, label='Test', color='#ff7f0e')
plt.ylabel('Accuracy', fontsize=12)
plt.title('Validation vs Test Accuracy per Class', fontsize=14)
plt.xticks(x, class_names, fontsize=10, rotation=30)
plt.ylim(0, 1.1)
plt.legend(fontsize=10)
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Add the value labels
for i, v in enumerate(class_accuracies):
    plt.text(i - width, v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(test_class_accuracies):
    plt.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'validation_vs_test_accuracy.pdf'))
plt.savefig(os.path.join(output_dir, 'validation_vs_test_accuracy.png'), dpi=300)
plt.show()
# Statistical significance tests
try:
    from scipy.stats import binomtest
    from statsmodels.stats.contingency_tables import mcnemar
    from scipy.stats import chi2_contingency
    # ... the statistical test functions are unchanged ...
    # Their full code is omitted here for length; copy them from the
    # previous version of this script.
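    # Minimal illustrative sketch, not one of the omitted functions: a
    # binomial test of whether validation accuracy exceeds the 1/NUM_CLASSES
    # chance level (binomtest requires scipy >= 1.7).
    def accuracy_vs_chance(true_classes, pred_classes, n_classes=NUM_CLASSES):
        n_correct = int(np.sum(true_classes == pred_classes))
        n_total = len(true_classes)
        result = binomtest(n_correct, n_total, p=1.0 / n_classes, alternative='greater')
        print(f"Accuracy vs. chance: p = {result.pvalue:.3e} "
              f"({n_correct}/{n_total} correct, chance = {1.0 / n_classes:.2f})")
        return result
    # accuracy_vs_chance(val_true_classes, val_pred_classes)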
except ImportError:
    print("\nWarning: scipy or statsmodels not found; skipping statistical tests")
    print("Install them for a more complete analysis:")
    print("pip install scipy statsmodels")
# ----------------------
# 8. Post-training probability clustering (raw-image clustering removed)
# ----------------------
print("\n===== Starting Advanced Analysis =====")
# Extract prediction probabilities (for post-training clustering) with the
# dedicated analysis generators
def extract_predictions(generator, model, dataset_name):
    """Extract prediction probabilities using a dedicated (unshuffled) generator."""
    probabilities = []
    true_labels = []
    file_paths = []
    # Rewind the generator
    generator.reset()
    # Process batch by batch
    for i in range(len(generator)):
        if i % 10 == 0:
            print(f"Processing {dataset_name} batch {i + 1}/{len(generator)}")
        # Fetch one batch
        batch_x, batch_y = generator[i]
        # Predict probabilities for the batch
        batch_probs = model.predict(batch_x, verbose=0)
        probabilities.append(batch_probs)
        true_labels.append(np.argmax(batch_y, axis=1))
        # File paths of the current batch (valid because shuffle=False)
        start_idx = i * generator.batch_size
        end_idx = start_idx + len(batch_y)
        current_filepaths = generator.filepaths[start_idx:end_idx]
        file_paths.extend(current_filepaths)
    # Stack the per-batch results
    probabilities = np.vstack(probabilities) if probabilities else np.array([])
    true_labels = np.concatenate(true_labels) if true_labels else np.array([])
    # Assemble a DataFrame
    df = pd.DataFrame(probabilities, columns=[f"prob_class_{i}" for i in range(NUM_CLASSES)])
    df['true_label'] = true_labels
    df['predicted_label'] = np.argmax(probabilities, axis=1) if probabilities.size > 0 else -1
    df['file_path'] = file_paths
    df['dataset'] = dataset_name
    return df
# Extract prediction probabilities for all three datasets
print("\nExtracting predictions for training set using analysis generator...")
train_df = extract_predictions(analysis_train_generator, best_model, "train")
print("Extracting predictions for validation set using analysis generator...")
val_df = extract_predictions(analysis_val_generator, best_model, "val")
print("Extracting predictions for test set using analysis generator...")
test_df = extract_predictions(analysis_test_generator, best_model, "test")
# Combine everything
all_df = pd.concat([train_df, val_df, test_df])
# Save all prediction probabilities to Excel
all_df.to_excel(os.path.join(output_dir, 'all_predictions.xlsx'), index=False)
print(f"Saved all predictions to {os.path.join(output_dir, 'all_predictions.xlsx')}")
# Pull out the 10-dimensional probability vectors
prob_vectors = all_df[[f'prob_class_{i}' for i in range(NUM_CLASSES)]].values
file_paths = all_df['file_path'].values
all_labels = all_df['true_label'].values
all_dataset_types = all_df['dataset'].values
# Post-training probability clustering
def perform_prob_clustering(prob_vectors, labels, dataset_types, file_paths, method='tsne'):
    """Run a dimensionality-reduction analysis on the prediction probabilities."""
    # Choose the reduction method, falling back where a library is missing
    if method == 'pca':
        try:
            from sklearn.decomposition import PCA
            reducer = PCA(n_components=2, random_state=42)
            method_name = 'PCA'
        except ImportError as e:
            print(f"PCA not available: {str(e)}")
            return None
    elif method == 'umap':
        try:
            from umap import UMAP
            reducer = UMAP(n_components=2, random_state=42)
            method_name = 'UMAP'
        except Exception as e:
            print(f"UMAP error: {str(e)}. Using PCA instead.")
            try:
                from sklearn.decomposition import PCA
                reducer = PCA(n_components=2, random_state=42)
                method_name = 'PCA (instead of UMAP)'
            except ImportError:
                print("PCA not available. Using t-SNE instead.")
                try:
                    from sklearn.manifold import TSNE
                    reducer = TSNE(n_components=2, perplexity=30, random_state=42)
                    method_name = 't-SNE (instead of UMAP)'
                except ImportError:
                    print("t-SNE not available. Dimensionality reduction skipped.")
                    return None
    else:  # default: t-SNE
        try:
            from sklearn.manifold import TSNE
            reducer = TSNE(n_components=2, perplexity=30, random_state=42)
            method_name = 't-SNE'
        except ImportError:
            print("t-SNE not available. Using PCA instead.")
            try:
                from sklearn.decomposition import PCA
                reducer = PCA(n_components=2, random_state=42)
                method_name = 'PCA (instead of t-SNE)'
            except ImportError:
                print("PCA not available. Dimensionality reduction skipped.")
                return None
    # Run the reduction
    print(f"Running {method_name} on probability vectors...")
    try:
        reduced_features = reducer.fit_transform(prob_vectors)
    except Exception as e:
        print(f"Error during dimensionality reduction: {str(e)}")
        return None
    # Assemble the result DataFrame
    result_df = pd.DataFrame({
        'file_path': file_paths,
        'true_label': labels,
        'predicted_label': np.argmax(prob_vectors, axis=1),
        'dataset': dataset_types,
        f'prob_{method_name}_X': reduced_features[:, 0],
        f'prob_{method_name}_Y': reduced_features[:, 1]
    })
    # Save as CSV
    csv_filename = os.path.join(output_dir, f'prob_{method_name}_coordinates.csv')
    result_df.to_csv(csv_filename, index=False)
    print(f"Saved probability {method_name} coordinates to {csv_filename}")
    # Visualization
    plt.figure(figsize=(12, 10))
    # Same color scheme as the original clustering plots
    class_colors = [
        'blue', 'red', 'green', 'purple', 'orange',
        'cyan', 'magenta', 'lime', 'brown', 'pink'
    ]
    # Marker style per dataset
    markers = {
        'train': 'o',  # circle
        'val': 's',    # square
        'test': '^'    # triangle
    }
    # Legend entries, one per class
    legend_handles = []
    # One scatter per (class, dataset) combination
    for class_idx in range(NUM_CLASSES):
        color = class_colors[class_idx]
        for dataset_type in ['train', 'val', 'test']:
            # Points of the current class and dataset
            mask = (labels == class_idx) & (dataset_types == dataset_type)
            if np.sum(mask) > 0:
                # Marker shape for this dataset
                marker = markers[dataset_type]
                # Draw the points
                scatter = plt.scatter(
                    reduced_features[mask, 0],
                    reduced_features[mask, 1],
                    color=color,
                    marker=marker,
                    s=40,  # fixed size
                    alpha=0.7
                )
                # Add each class to the legend once (using its train marker)
                if dataset_type == 'train':
                    legend_handles.append(scatter)
    # Legend for the dataset types
    from matplotlib.lines import Line2D
    dataset_legend = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=10, label='Train'),
        Line2D([0], [0], marker='s', color='w', markerfacecolor='gray', markersize=10, label='Validation'),
        Line2D([0], [0], marker='^', color='w', markerfacecolor='gray', markersize=10, label='Test')
    ]
    # Build both legends
    class_legend = plt.legend(handles=legend_handles, labels=class_names,
                              title="Classes", loc="upper right", fontsize=10, ncol=1)
    plt.gca().add_artist(class_legend)
    plt.legend(handles=dataset_legend, title="Dataset Types",
               loc="upper left", fontsize=10)
    plt.title(f'Probability {method_name} Visualization', fontsize=16)
    plt.xlabel(f'{method_name} Dimension 1', fontsize=12)
    plt.ylabel(f'{method_name} Dimension 2', fontsize=12)
    plt.grid(alpha=0.2)
    # Save vector and bitmap versions
    filename = os.path.join(output_dir, f"prob_{method.lower()}_visualization.pdf")
    plt.savefig(filename)
    plt.savefig(filename.replace('.pdf', '.png'), dpi=300)
    plt.close()  # close the figure to avoid piling up windows
    print(f"Saved probability {method_name} plot to {filename}")
    return result_df
# Run the different reduction methods
print("\nVisualizing with dimensionality reduction methods on probability vectors...")
tsne_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='tsne')
pca_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='pca')
umap_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='umap')
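# Illustrative addition, not part of the original analysis: the silhouette
# score of the 10-D probability vectors against the true labels, as a single
# number summarizing how well-separated the classes are.
try:
    from sklearn.metrics import silhouette_score
    sil = silhouette_score(prob_vectors, all_labels)
    print(f"Silhouette score of probability vectors (true labels): {sil:.3f}")
except Exception as e:
    print(f"Silhouette score skipped: {e}")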
# ----------------------
# 9. Generate average images with the standalone class
# ----------------------
print("\nGenerating average leaf images for each class...")
avg_image_generator = ClassAverageImageGenerator(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
avg_image_generator.generate_all_class_averages(
    source_dir=source_dir,  # use the original leaves directory
    output_dir=os.path.join(output_dir, "class_average_images"),
    max_per_page=10
)
print("\n===== All analysis completed successfully! =====")
# ----------------------
# Usage example: the average-image generator on its own
# ----------------------
def generate_average_images_standalone():
    """Example of using the average-image generation on its own."""
    generator = ClassAverageImageGenerator()
    # Average images for the training set
    generator.generate_all_class_averages(
        source_dir=train_dir,
        output_dir="train_class_averages",
        max_per_page=8
    )
    # Average images for the whole dataset
    generator.generate_all_class_averages(
        source_dir=source_dir,
        output_dir="all_class_averages",
        max_per_page=10
    )
# Uncomment the line below to run the average-image generation on its own
# generate_average_images_standalone()