# Code for recognizing ten groups of photos. The code for recognizing two, three, or five groups of photos follows the same structure, differing only in the number of groups.
# 1. Automatic data splitting: split train/val/test from the leaves folder
# 2. Standalone class-average-image generator
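# For example, a five-group variant would only change the class count and the
# save path (hypothetical values shown, not part of this script):
#   NUM_CLASSES = 5
#   model_save_path = "leaf_classifier_5class.h5"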
import os
import sys
import warnings
import random
import shutil
from sklearn.model_selection import train_test_split
# Set environment variables to reduce logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Handle possible import errors
try:
    import numpy as np
except ImportError:
    print("Error: NumPy not installed. Please install NumPy 1.19.5")
    sys.exit(1)
try:
    import tensorflow as tf
except ImportError:
    print("Error: TensorFlow not installed. Please install TensorFlow 2.5.0")
    sys.exit(1)
# Try importing the remaining libraries, with a friendly error message
optional_libs = ['pandas', 'umap', 'PIL']
missing_libs = []
for lib in optional_libs:
    try:
        __import__(lib)
    except ImportError:
        missing_libs.append(lib)
if missing_libs:
    print(f"Warning: Missing optional libraries: {', '.join(missing_libs)}")
    print("Some advanced features will be disabled.")
    DISABLE_ADVANCED_FEATURES = True
else:
    DISABLE_ADVANCED_FEATURES = False
import pandas as pd
from PIL import Image
import umap
from umap import UMAP
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_curve, auc
from itertools import cycle
import glob
# Handle grpcio compatibility warnings
try:
    from grpc._cython import cygrpc
except ImportError:
    pass  # ignore this error
from average_image_generator import ClassAverageImageGenerator
# Use an English font everywhere
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['savefig.format'] = 'pdf'  # save as vector graphics
plt.rcParams['pdf.fonttype'] = 42  # keep text editable
plt.rcParams['ps.fonttype'] = 42  # keep text editable
# Disable the GPU (if needed)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
tf.config.set_visible_devices([], 'GPU')
# Verify the device settings
print("Available devices:", tf.config.list_physical_devices())
print("GPU available:", tf.config.list_physical_devices('GPU'))
# ----------------------
# 0. Set the working directory
# ----------------------
os.chdir(".../working/dir")
# ----------------------
# New: automatic data-splitting class
# ----------------------
class DataSplitter:
    def __init__(self, source_dir="leaves", output_base_dir="split_data",
                 split_ratios=[0.6, 0.2, 0.2], random_seed=42, clean_output=True):
        """
        Initialize the data splitter.
        Args:
            source_dir: source data directory containing one subfolder per class
            output_base_dir: output directory
            split_ratios: split ratios [train, val, test]
            random_seed: random seed
            clean_output: whether to empty the output directory before splitting
        """
        self.source_dir = source_dir
        self.output_base_dir = output_base_dir
        self.split_ratios = split_ratios
        self.random_seed = random_seed
        self.clean_output = clean_output
        # Validate the split ratios
        if abs(sum(split_ratios) - 1.0) > 0.01:
            raise ValueError(f"Split ratios must sum to 1.0, got {sum(split_ratios)}")
        # Set up the output directories
        self.train_dir = os.path.join(output_base_dir, "train")
        self.val_dir = os.path.join(output_base_dir, "val")
        self.test_dir = os.path.join(output_base_dir, "test")
        # Seed the random number generators
        random.seed(random_seed)
        np.random.seed(random_seed)
    def split_data(self):
        """Perform the data split."""
        print("Starting automatic data split...")
        print(f"Source directory: {self.source_dir}")
        print(f"Output directory: {self.output_base_dir}")
        print(f"Split ratios: Train {self.split_ratios[0] * 100}%, "
              f"Val {self.split_ratios[1] * 100}%, Test {self.split_ratios[2] * 100}%")
        # Empty or create the output directories
        self._prepare_output_directories()
        # Collect all classes
        class_folders = [f for f in os.listdir(self.source_dir)
                         if os.path.isdir(os.path.join(self.source_dir, f))]
        class_folders.sort()
        print(f"Found {len(class_folders)} classes: {class_folders}")
        total_files = 0
        detailed_stats = {}
        for class_folder in class_folders:
            class_path = os.path.join(self.source_dir, class_folder)
            # Collect all image files for this class
            image_files = self._get_image_files(class_path)
            if not image_files:
                print(f"Warning: no image files found for class {class_folder}")
                continue
            # Shuffle the file order
            random.shuffle(image_files)
            # Compute the split points; int() truncation sends any remainder to test
            # (e.g. 100 images with [0.6, 0.2, 0.2] -> 60 train, 20 val, 20 test)
            n_total = len(image_files)
            n_train = int(n_total * self.split_ratios[0])
            n_val = int(n_total * self.split_ratios[1])
            # Split the files
            train_files = image_files[:n_train]
            val_files = image_files[n_train:n_train + n_val]
            test_files = image_files[n_train + n_val:]
            # Copy the files into the corresponding directories
            self._copy_files(train_files, class_folder, self.train_dir)
            self._copy_files(val_files, class_folder, self.val_dir)
            self._copy_files(test_files, class_folder, self.test_dir)
            total_files += n_total
            # Record detailed statistics
            detailed_stats[class_folder] = {
                'total': n_total,
                'train': len(train_files),
                'val': len(val_files),
                'test': len(test_files)
            }
            print(f"Class {class_folder}: {n_total} images -> "
                  f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")
        print(f"\nData split complete! Processed {total_files} images in total")
        # Verify the split results
        self._verify_split_results(detailed_stats)
        return self.train_dir, self.val_dir, self.test_dir
    def _prepare_output_directories(self):
        """Prepare the output directories (empty or create them)."""
        if self.clean_output and os.path.exists(self.output_base_dir):
            print(f"Emptying output directory: {self.output_base_dir}")
            shutil.rmtree(self.output_base_dir)
        # Create the output directories
        os.makedirs(self.train_dir, exist_ok=True)
        os.makedirs(self.val_dir, exist_ok=True)
        os.makedirs(self.test_dir, exist_ok=True)
        print(f"Created output directory: {self.output_base_dir}")

    def _get_image_files(self, class_path):
        """Collect all image files in a directory."""
        image_files = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
            image_files.extend(glob.glob(os.path.join(class_path, ext)))
            image_files.extend(glob.glob(os.path.join(class_path, ext.upper())))
        return image_files
    def _copy_files(self, file_list, class_name, target_dir):
        """Copy files into the target directory."""
        class_target_dir = os.path.join(target_dir, class_name)
        os.makedirs(class_target_dir, exist_ok=True)
        for file_path in file_list:
            filename = os.path.basename(file_path)
            target_path = os.path.join(class_target_dir, filename)
            shutil.copy2(file_path, target_path)

    def _verify_split_results(self, detailed_stats):
        """Verify the split results."""
        print("\n=== Split verification ===")
        # Verify the file count in each split directory
        for split_name, split_dir in [('train', self.train_dir),
                                      ('val', self.val_dir),
                                      ('test', self.test_dir)]:
            actual_count = 0
            for class_folder in os.listdir(split_dir):
                class_path = os.path.join(split_dir, class_folder)
                if os.path.isdir(class_path):
                    files = [f for f in os.listdir(class_path)
                             if os.path.isfile(os.path.join(class_path, f))]
                    actual_count += len(files)
            expected_count = sum(stats[split_name] for stats in detailed_stats.values())
            print(f"{split_name}: expected {expected_count}, actual {actual_count}")
            if actual_count != expected_count:
                print(f"Warning: file count mismatch in the {split_name} directory!")
# ----------------------
# 1. Automatic data split
# ----------------------
# Split configuration
split_ratios = [0.6, 0.2, 0.2]  # train: 60%, val: 20%, test: 20%
source_dir = "leaves"  # raw data directory
output_base_dir = "split_data"  # output directory for the split data
# Run the split
splitter = DataSplitter(
    source_dir=source_dir,
    output_base_dir=output_base_dir,
    split_ratios=split_ratios,
    random_seed=42,
    clean_output=True
)
train_dir, val_dir, test_dir = splitter.split_data()
# ----------------------
# 2. Set the data paths (using the results of the automatic split)
# ----------------------
model_save_path = "leaf_classifier_10class.h5"  # model save path
output_dir = "analysis_results"  # output directory for the analysis results
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
# Check that the paths exist
if not os.path.exists(train_dir):
    raise FileNotFoundError(f"Training directory {train_dir} not found!")
if not os.path.exists(val_dir):
    raise FileNotFoundError(f"Validation directory {val_dir} not found!")
if not os.path.exists(test_dir):
    raise FileNotFoundError(f"Test directory {test_dir} not found!")
# ----------------------
# 3. Create the data generators - adapted for ten-class classification
# ----------------------
IMG_HEIGHT = 224  # image height
IMG_WIDTH = 224  # image width
BATCH_SIZE = 32
NUM_CLASSES = 10  # 10 classes
# Data augmentation configuration (somewhat stronger to cope with more classes)
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=8,
    width_shift_range=0.13,
    height_shift_range=0.13,
    zoom_range=0.1,
    channel_shift_range=15,
    shear_range=0.1,
    fill_mode='reflect',
    cval=0.0,
    vertical_flip=False,
    horizontal_flip=False,
    brightness_range=[0.9, 1.1]
)
val_datagen = ImageDataGenerator(rescale=1.0 / 255)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)
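# Optional sanity check (a minimal sketch, not part of the original workflow):
# pull a single augmented batch and eyeball the shapes and the value range
# (roughly [0, 1] after rescaling). Uncomment to run.
# preview_gen = train_datagen.flow_from_directory(
#     train_dir, target_size=(IMG_HEIGHT, IMG_WIDTH),
#     batch_size=4, class_mode='categorical')
# preview_x, preview_y = next(preview_gen)
# print("Batch:", preview_x.shape, "range:", preview_x.min(), "-", preview_x.max())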
# Create the main generators (used for training)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True  # shuffling during training matters
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)
# Create dedicated analysis generators (no shuffling; used for post-processing)
print("\nCreating dedicated analysis generators...")
analysis_train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,  # key: keep the order fixed
    seed=42
)
analysis_val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
    seed=42
)
analysis_test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
    seed=42
)
# Make sure the analysis generators share the main generators' class mapping
assert analysis_train_generator.class_indices == train_generator.class_indices
assert analysis_val_generator.class_indices == val_generator.class_indices
assert analysis_test_generator.class_indices == test_generator.class_indices
print("Analysis generators created; class mappings are consistent")
# Get the class names and indices
class_names = list(train_generator.class_indices.keys())
class_indices = train_generator.class_indices
print(f"Class indices: {class_indices}")
print(f"Class names: {class_names}")
# ----------------------
# 4. Build the model - adapted for ten classes (larger classification head)
# ----------------------
def build_model():
    # Use a pretrained model as the feature extractor
    base_model = MobileNetV2(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False
    # Add a global average pooling layer to extract features
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # Keep a feature-extraction model
    feature_model = Model(inputs=base_model.input, outputs=x)
    # Classification head
    y = Dense(1024, activation='relu')(x)
    y = Dropout(0.5)(y)
    y = Dense(512, activation='relu')(y)
    y = Dropout(0.3)(y)
    y = Dense(256, activation='relu')(y)
    y = Dense(128, activation='relu')(y)
    predictions = Dense(NUM_CLASSES, activation='softmax')(y)
    # Full model
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model, feature_model  # return the classifier and the feature extractor

model, feature_model = build_model()
model.summary()  # print the model structure
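# Optional follow-up (a hypothetical sketch, not part of this run): once the
# frozen-backbone head converges, the top MobileNetV2 layers could be unfrozen
# for a fine-tuning pass at a lower learning rate. Uncomment to try.
# for layer in model.layers[-30:]:
#     layer.trainable = True
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
#               loss='categorical_crossentropy', metrics=['accuracy'])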
# ----------------------
# 5. Train the model
# ----------------------
# Callback configuration (increased patience)
callbacks = [
    EarlyStopping(patience=10, monitor='val_loss', restore_best_weights=True),
    ModelCheckpoint(model_save_path, save_best_only=True)
]
# Start training
print("\nStarting training for 10-class classification...")
history = model.fit(
    train_generator,
    workers=32,
    epochs=512,  # large maximum; early stopping ends training in practice
    validation_data=val_generator,
    callbacks=callbacks
)
# ----------------------
# 6. Evaluate the model - English labels
# ----------------------
# Load the best model
best_model = load_model(model_save_path)
best_model.trainable = False  # make sure the weights are not updated
# Evaluate on the validation set
loss, accuracy = best_model.evaluate(val_generator)
print(f"\nValidation accuracy: {accuracy * 100:.2f}%")
# Plot the training curves and save them as vector graphics
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy (10 Classes)')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss (10 Classes)')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'training_history_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'training_history_10class.png'), dpi=300)
plt.show()
# Get all validation-set predictions and true labels
print("\nGenerating predictions...")
val_preds = best_model.predict(val_generator)
val_pred_classes = np.argmax(val_preds, axis=1)  # predicted class indices
val_true_classes = val_generator.classes  # true class indices (order is fixed because shuffle=False)
# 1. Overall evaluation metrics
print("\n===== Overall Metrics =====")
print(f"Validation accuracy: {accuracy_score(val_true_classes, val_pred_classes) * 100:.2f}%")
# 2. Confusion matrix
cm = confusion_matrix(val_true_classes, val_pred_classes)
print("\nConfusion Matrix:")
print(cm)
# Visualize the confusion matrix (adjusted for the larger number of classes)
plt.figure(figsize=(14, 12))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix (10 Classes)', fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45, ha='right', fontsize=10)
plt.yticks(tick_marks, class_names, fontsize=10)
# Add count labels
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=9)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'confusion_matrix_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'confusion_matrix_10class.png'), dpi=300, bbox_inches='tight')
plt.show()
# 3. Classification report
print("\nClassification Report:")
report = classification_report(
    val_true_classes,
    val_pred_classes,
    target_names=class_names,
    output_dict=True  # dictionary output is easier to save
)
print(pd.DataFrame(report).transpose())
# Save the classification report
report_df = pd.DataFrame(report).transpose()
report_df.to_csv(os.path.join(output_dir, 'classification_report.csv'))
report_df.to_excel(os.path.join(output_dir, 'classification_report.xlsx'))
# 4. Compute per-class metrics
print("\n===== Class-wise Performance =====")
class_accuracies = []
class_recalls = []
class_precisions = []
for i, class_name in enumerate(class_names):
    # Indices of the samples belonging to the current class
    # (renamed from class_indices to avoid shadowing the global class mapping)
    sample_indices = np.where(val_true_classes == i)[0]
    # Per-class accuracy (computed this way it is identical to recall)
    correct = np.sum(val_pred_classes[sample_indices] == i)
    total = len(sample_indices)
    acc = correct / total if total > 0 else 0
    class_accuracies.append(acc)
    # Recall
    recall = correct / total if total > 0 else 0
    class_recalls.append(recall)
    # Precision
    predicted_indices = np.where(val_pred_classes == i)[0]
    precision = correct / len(predicted_indices) if len(predicted_indices) > 0 else 0
    class_precisions.append(precision)
    print(f"Class '{class_name}':")
    print(f"  Accuracy: {acc * 100:.2f}% ({correct}/{total})")
    print(f"  Recall: {recall * 100:.2f}%")
    print(f"  Precision: {precision * 100:.2f}%")
# Visualize per-class performance (horizontal bars suit a larger number of classes)
plt.figure(figsize=(12, 10))
y_pos = np.arange(len(class_names))
# Three horizontal bar series
plt.barh(y_pos - 0.25, class_accuracies, height=0.25, label='Accuracy', color='#1f77b4')
plt.barh(y_pos, class_recalls, height=0.25, label='Recall', color='#ff7f0e')
plt.barh(y_pos + 0.25, class_precisions, height=0.25, label='Precision', color='#2ca02c')
plt.xlabel('Score', fontsize=12)
plt.title('Class-wise Performance Metrics (10 Classes)', fontsize=16)
plt.yticks(y_pos, class_names, fontsize=10)
plt.xlim(0, 1.1)
plt.legend(loc='lower right', fontsize=10)
plt.grid(axis='x', linestyle='--', alpha=0.7)
# Add value labels
for i, v in enumerate(class_accuracies):
    plt.text(v + 0.02, i - 0.25, f"{v:.2f}", color='black', fontsize=9)
for i, v in enumerate(class_recalls):
    plt.text(v + 0.02, i, f"{v:.2f}", color='black', fontsize=9)
for i, v in enumerate(class_precisions):
    plt.text(v + 0.02, i + 0.25, f"{v:.2f}", color='black', fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'class_performance_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'class_performance_10class.png'), dpi=300)
plt.show()
# 5. ROC curves (one-vs-rest, ten classes)
plt.figure(figsize=(12, 10))
colors = cycle(['blue', 'red', 'green', 'purple', 'orange',
                'cyan', 'magenta', 'lime', 'brown', 'pink'])
# Compute a ROC curve and AUC for each class
for i, color in zip(range(NUM_CLASSES), colors):
    # Binary labels for the current class
    y_true_i = np.where(val_true_classes == i, 1, 0)
    y_score_i = val_preds[:, i]
    # ROC curve
    fpr, tpr, _ = roc_curve(y_true_i, y_score_i)
    roc_auc = auc(fpr, tpr)
    # Plot it
    plt.plot(fpr, tpr, color=color, lw=2,
             label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curves for 10-Class Classification', fontsize=16)
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'roc_curves_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'roc_curves_10class.png'), dpi=300)
plt.show()
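# Optional (a sketch using scikit-learn's roc_auc_score; not in the original
# analysis): a single macro-averaged one-vs-rest AUC summarizing all classes.
# from sklearn.metrics import roc_auc_score
# macro_auc = roc_auc_score(val_true_classes, val_preds, multi_class='ovr', average='macro')
# print(f"Macro OvR AUC: {macro_auc:.3f}")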
# ----------------------
# 7. Final evaluation on the test set
# ----------------------
print("\n===== Test Set Evaluation =====")
test_loss, test_accuracy = best_model.evaluate(test_generator)
print(f"Test accuracy: {test_accuracy * 100:.2f}%")
# Get all test-set predictions and true labels
test_preds = best_model.predict(test_generator)
test_pred_classes = np.argmax(test_preds, axis=1)
test_true_classes = test_generator.classes
# Test-set confusion matrix
test_cm = confusion_matrix(test_true_classes, test_pred_classes)
print("\nTest Confusion Matrix:")
print(test_cm)
# Visualize the test-set confusion matrix
plt.figure(figsize=(14, 12))
plt.imshow(test_cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Test Confusion Matrix (10 Classes)', fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45, ha='right', fontsize=10)
plt.yticks(tick_marks, class_names, fontsize=10)
thresh = test_cm.max() / 2.
for i in range(test_cm.shape[0]):
    for j in range(test_cm.shape[1]):
        plt.text(j, i, format(test_cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if test_cm[i, j] > thresh else "black",
                 fontsize=9)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'test_confusion_matrix_10class.pdf'))
plt.savefig(os.path.join(output_dir, 'test_confusion_matrix_10class.png'), dpi=300, bbox_inches='tight')
plt.show()
# Test-set classification report
print("\nTest Classification Report:")
test_report = classification_report(
    test_true_classes,
    test_pred_classes,
    target_names=class_names,
    output_dict=True
)
print(pd.DataFrame(test_report).transpose())
# Save the test-set classification report
test_report_df = pd.DataFrame(test_report).transpose()
test_report_df.to_csv(os.path.join(output_dir, 'test_classification_report.csv'))
# Test-set per-class performance
print("\n===== Test Class-wise Performance =====")
test_class_accuracies = []
for i, class_name in enumerate(class_names):
    sample_indices = np.where(test_true_classes == i)[0]
    correct = np.sum(test_pred_classes[sample_indices] == i)
    total = len(sample_indices)
    acc = correct / total if total > 0 else 0
    test_class_accuracies.append(acc)
    print(f"Class '{class_name}': Accuracy = {acc * 100:.2f}% ({correct}/{total})")
# Compare per-class validation vs test accuracy
plt.figure(figsize=(10, 6))
x = np.arange(len(class_names))
width = 0.25
# Reuse the per-class validation accuracies computed earlier (class_accuracies)
plt.bar(x - width, class_accuracies, width, label='Validation', color='#1f77b4')
plt.bar(x, test_class_accuracies, width, label='Test', color='#ff7f0e')
plt.ylabel('Accuracy', fontsize=12)
plt.title('Validation vs Test Accuracy per Class', fontsize=14)
plt.xticks(x, class_names, fontsize=10, rotation=30)
plt.ylim(0, 1.1)
plt.legend(fontsize=10)
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Add value labels
for i, v in enumerate(class_accuracies):
    plt.text(i - width, v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
for i, v in enumerate(test_class_accuracies):
    plt.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'validation_vs_test_accuracy.pdf'))
plt.savefig(os.path.join(output_dir, 'validation_vs_test_accuracy.png'), dpi=300)
plt.show()
# Add statistical significance tests
try:
    from scipy.stats import binomtest
    from statsmodels.stats.contingency_tables import mcnemar
    from scipy.stats import chi2_contingency
    # ... the statistical test functions are unchanged ...
    # Their full code is omitted here for length;
    # copy these functions from the earlier version.
except ImportError:
    print("\nWarning: scipy or statsmodels not found; statistical tests skipped")
    print("Install them for a more complete analysis:")
    print("pip install scipy statsmodels")
# ----------------------
# 8. Post-training probability clustering (kept; the raw-image clustering was removed)
# ----------------------
print("\n===== Starting Advanced Analysis =====")
# Extract prediction probabilities (for post-training clustering) using the dedicated generators
def extract_predictions(generator, model, dataset_name):
    """Extract prediction probabilities (using a dedicated, ordered generator)."""
    probabilities = []
    true_labels = []
    file_paths = []
    # Reset the generator
    generator.reset()
    # Process batch by batch
    for i in range(len(generator)):
        if i % 10 == 0:
            print(f"Processing {dataset_name} batch {i + 1}/{len(generator)}")
        # Get one batch
        batch_x, batch_y = generator[i]
        # Predict probabilities
        batch_probs = model.predict(batch_x, verbose=0)
        probabilities.append(batch_probs)
        true_labels.append(np.argmax(batch_y, axis=1))
        # File paths of the current batch
        start_idx = i * generator.batch_size
        end_idx = start_idx + len(batch_y)
        current_filepaths = generator.filepaths[start_idx:end_idx]
        file_paths.extend(current_filepaths)
    # Merge the results
    probabilities = np.vstack(probabilities) if probabilities else np.array([])
    true_labels = np.concatenate(true_labels) if true_labels else np.array([])
    # Build a DataFrame
    df = pd.DataFrame(probabilities, columns=[f"prob_class_{i}" for i in range(NUM_CLASSES)])
    df['true_label'] = true_labels
    df['predicted_label'] = np.argmax(probabilities, axis=1) if probabilities.size > 0 else -1
    df['file_path'] = file_paths
    df['dataset'] = dataset_name
    return df
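# Note on the indexing above (explanatory, not in the original): with
# shuffle=False the generator yields batches in generator.filepaths order, so
# slicing filepaths by batch index stays aligned with the probability rows.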
# Extract prediction probabilities for all datasets (using the dedicated generators)
print("\nExtracting predictions for training set using analysis generator...")
train_df = extract_predictions(analysis_train_generator, best_model, "train")
print("Extracting predictions for validation set using analysis generator...")
val_df = extract_predictions(analysis_val_generator, best_model, "val")
print("Extracting predictions for test set using analysis generator...")
test_df = extract_predictions(analysis_test_generator, best_model, "test")
# Merge all data
all_df = pd.concat([train_df, val_df, test_df])
# Save all prediction probabilities to Excel
all_df.to_excel(os.path.join(output_dir, 'all_predictions.xlsx'), index=False)
print(f"Saved all predictions to {os.path.join(output_dir, 'all_predictions.xlsx')}")
# Extract the probability vectors (10-dimensional)
prob_vectors = all_df[[f'prob_class_{i}' for i in range(NUM_CLASSES)]].values
file_paths = all_df['file_path'].values
all_labels = all_df['true_label'].values
all_dataset_types = all_df['dataset'].values
# Post-training probability clustering analysis
def perform_prob_clustering(prob_vectors, labels, dataset_types, file_paths, method='tsne'):
    """Run a dimensionality-reduction/clustering analysis on the predicted probabilities."""
    # Choose the dimensionality-reduction method
    if method == 'pca':
        try:
            from sklearn.decomposition import PCA
            reducer = PCA(n_components=2, random_state=42)
            method_name = 'PCA'
        except ImportError as e:
            print(f"PCA not available: {str(e)}")
            return None
    elif method == 'umap':
        try:
            from umap import UMAP
            reducer = UMAP(n_components=2, random_state=42)
            method_name = 'UMAP'
        except Exception as e:
            print(f"UMAP error: {str(e)}. Using PCA instead.")
            try:
                from sklearn.decomposition import PCA
                reducer = PCA(n_components=2, random_state=42)
                method_name = 'PCA (instead of UMAP)'
            except ImportError:
                print("PCA not available. Using t-SNE instead.")
                try:
                    from sklearn.manifold import TSNE
                    reducer = TSNE(n_components=2, perplexity=30, random_state=42)
                    method_name = 't-SNE (instead of UMAP)'
                except ImportError:
                    print("t-SNE not available. Dimensionality reduction skipped.")
                    return None
    else:  # default to t-SNE
        try:
            from sklearn.manifold import TSNE
            reducer = TSNE(n_components=2, perplexity=30, random_state=42)
            method_name = 't-SNE'
        except ImportError:
            print("t-SNE not available. Using PCA instead.")
            try:
                from sklearn.decomposition import PCA
                reducer = PCA(n_components=2, random_state=42)
                method_name = 'PCA (instead of t-SNE)'
            except ImportError:
                print("PCA not available. Dimensionality reduction skipped.")
                return None
    # Run the dimensionality reduction
    print(f"Running {method_name} on probability vectors...")
    try:
        reduced_features = reducer.fit_transform(prob_vectors)
    except Exception as e:
        print(f"Error during dimensionality reduction: {str(e)}")
        return None
    # Build the result DataFrame
    result_df = pd.DataFrame({
        'file_path': file_paths,
        'true_label': labels,
        'predicted_label': np.argmax(prob_vectors, axis=1),
        'dataset': dataset_types,
        f'prob_{method_name}_X': reduced_features[:, 0],
        f'prob_{method_name}_Y': reduced_features[:, 1]
    })
    # Save to CSV
    csv_filename = os.path.join(output_dir, f'prob_{method_name}_coordinates.csv')
    result_df.to_csv(csv_filename, index=False)
    print(f"Saved probability {method_name} coordinates to {csv_filename}")
    # Visualization
    plt.figure(figsize=(12, 10))
    # Same color scheme as the original clustering plots
    class_colors = [
        'blue', 'red', 'green', 'purple', 'orange',
        'cyan', 'magenta', 'lime', 'brown', 'pink'
    ]
    # Marker style per dataset type
    markers = {
        'train': 'o',  # circle
        'val': 's',  # square
        'test': '^'  # triangle
    }
    # Legend entries, one per class
    legend_handles = []
    # One scatter per (class, dataset type) combination
    for class_idx in range(NUM_CLASSES):
        color = class_colors[class_idx]
        for dataset_type in ['train', 'val', 'test']:
            # Select the points of the current class and dataset type
            mask = (labels == class_idx) & (dataset_types == dataset_type)
            if np.sum(mask) > 0:
                # Marker shape
                marker = markers[dataset_type]
                # Plot the points
                scatter = plt.scatter(
                    reduced_features[mask, 0],
                    reduced_features[mask, 1],
                    color=color,
                    marker=marker,
                    s=40,  # fixed size
                    alpha=0.7
                )
                # Add each class to the legend only once (using the train marker)
                if dataset_type == 'train':
                    legend_handles.append(scatter)
    # Legend for the dataset types
    from matplotlib.lines import Line2D
    dataset_legend = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=10, label='Train'),
        Line2D([0], [0], marker='s', color='w', markerfacecolor='gray', markersize=10, label='Validation'),
        Line2D([0], [0], marker='^', color='w', markerfacecolor='gray', markersize=10, label='Test')
    ]
    # Build the legends
    class_legend = plt.legend(handles=legend_handles, labels=class_names,
                              title="Classes", loc="upper right", fontsize=10, ncol=1)
    plt.gca().add_artist(class_legend)
    plt.legend(handles=dataset_legend, title="Dataset Types",
               loc="upper left", fontsize=10)
    plt.title(f'Probability {method_name} Visualization', fontsize=16)
    plt.xlabel(f'{method_name} Dimension 1', fontsize=12)
    plt.ylabel(f'{method_name} Dimension 2', fontsize=12)
    plt.grid(alpha=0.2)
    # Save vector and bitmap versions
    filename = os.path.join(output_dir, f"prob_{method.lower()}_visualization.pdf")
    plt.savefig(filename)
    plt.savefig(filename.replace('.pdf', '.png'), dpi=300)
    plt.close()  # close the figure to avoid piling up open plots
    print(f"Saved probability {method_name} plot to {filename}")
    return result_df
# Run the different dimensionality-reduction methods
print("\nVisualizing with dimensionality reduction methods on probability vectors...")
tsne_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='tsne')
pca_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='pca')
umap_result = perform_prob_clustering(prob_vectors, all_labels, all_dataset_types, file_paths, method='umap')
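# Optional (a sketch; assumes t-SNE actually ran, so the coordinate columns
# are named 'prob_t-SNE_X'/'prob_t-SNE_Y' — the fallbacks above would change
# that): a silhouette score quantifying how well the embedding separates the
# true classes. Uncomment to run.
# from sklearn.metrics import silhouette_score
# if tsne_result is not None:
#     coords = tsne_result[['prob_t-SNE_X', 'prob_t-SNE_Y']].values
#     print("t-SNE silhouette:", silhouette_score(coords, all_labels))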
# ----------------------
# 9. Generate average images with the standalone class
# ----------------------
print("\nGenerating average leaf images for each class...")
avg_image_generator = ClassAverageImageGenerator(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
avg_image_generator.generate_all_class_averages(
    source_dir=source_dir,  # use the original leaves directory
    output_dir=os.path.join(output_dir, "class_average_images"),
    max_per_page=10
)
print("\n===== All analysis completed successfully! =====")
# ----------------------
# Usage example: the average-image generator on its own
# ----------------------
def generate_average_images_standalone():
    """Example of using the average-image generation on its own."""
    generator = ClassAverageImageGenerator()
    # Average images for the training set
    generator.generate_all_class_averages(
        source_dir=train_dir,
        output_dir="train_class_averages",
        max_per_page=8
    )
    # Average images for the whole dataset
    generator.generate_all_class_averages(
        source_dir=source_dir,
        output_dir="all_class_averages",
        max_per_page=10
    )

# Uncomment the line below to run the average-image generation on its own
# generate_average_images_standalone()