[pytorch] (数据集平衡)鸢尾花-优化01

1、观察数据集的平衡性

# 1、观察源数据中train 的总数
def count_total_size(dir_path,name=''):
    sub_folders = [f for f in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path,f))]
    
    lens = []
    for sub_folder in sub_folders:
        image_files = [f for f in os.listdir(os.path.join(dir_path,sub_folder)) if f.endswith(('.jpg','.png','.jpeg'))]
        len_images = len(image_files)
        lens.append({"name":sub_folder,"num":len_images})
#         print(f'Total number of {name} images:{len_images}')
    return lens

# 画柱状图进行比较
def plot_image_counts(dir_path, title='Distribution of Image Counts'):
    num_objs = count_total_size(dir_path, dir_path.split('/')[-1])
    names = [obj['name'] for obj in num_objs]
    nums = [obj['num'] for obj in num_objs]

    fig, ax = plt.subplots()
    bar_container = ax.bar(names, nums)
    fig.set_size_inches(25, 10)
    plt.ylabel('Number of Images')
    plt.title(title)
    plt.xticks(rotation=0, ha='center')
    plt.tight_layout()
    plt.show()

plot_image_counts(train_dir,"before dataset adjust")

通过观察数据集的分布情况可以知道目前的数据集是非常不平衡的,最低数据集种类和最高数据集种类的的差距有10倍左右,所以可以考虑平衡数据集来对训练结果进行优化操作。

2、对数据集进行调整

  • 计算类别平均数
  • 高于平均数的类别裁剪图片
  • 低于平均数的类别增强图片
# 使用数据增强增加欠缺类别的硬数据

import os
import shutil
import torch
import torchvision.transforms as transforms
from PIL import Image

def process_data_without_modifying_source(source_folder, target_folder):
    if os.path.exists(target_folder):
        try:
            shutil.rmtree(target_folder)
            print(f"文件夹 {target_folder} 已成功删除。")
        except Exception as e:
            print(f"删除文件夹时出现错误:{e}")
    
    # 如果目标文件夹不存在,则创建
    if not os.path.exists(target_folder):
        os.makedirs(target_folder)

    # 复制源文件夹结构到目标文件夹
    for root, dirs, files in os.walk(source_folder):
        for dir_name in dirs:
            new_dir_path = os.path.join(target_folder, os.path.relpath(os.path.join(root, dir_name), source_folder))
            if not os.path.exists(new_dir_path):
                os.makedirs(new_dir_path)
        for file_name in files:
            if file_name.endswith(('.jpg', '.png', '.jpeg')):
                source_file_path = os.path.join(root, file_name)
                target_file_path = os.path.join(target_folder, os.path.relpath(root, source_folder), file_name)
                shutil.copy2(source_file_path, target_file_path)

    categories = [d for d in os.listdir(target_folder) if os.path.isdir(os.path.join(target_folder, d))]
    all_images = []
    for category in categories:
        category_path = os.path.join(target_folder, category)
        images = [f for f in os.listdir(category_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        all_images.append((category, len(images), images))

    total_images = sum([count for _, count, _ in all_images])
    average = total_images // len(categories)
    
    print("########平均数据数量是:",average)

    transform = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    ])

    for category, count, images in all_images:
        if len(images) < average:
            num_to_generate = average - len(images)
            generated_images = 0
            for _ in range(num_to_generate):
                random_img_index = torch.randint(0, len(images), (1,)).item()
                img_path = os.path.join(target_folder, category, images[random_img_index])
                img = Image.open(img_path)
                img_tensor = transform(img)
                save_path = os.path.join(target_folder, category, f'augmented_{generated_images}.jpg')
                img_tensor.save(save_path)
                generated_images += 1
        elif len(images) > average:
            excess_count = len(images) - average
            for _ in range(excess_count):
                random_img_index = torch.randint(0, len(images), (1,)).item()
                img_path = os.path.join(target_folder, category, images[random_img_index])
                os.remove(img_path)
                images.pop(random_img_index)

    return target_folder

source_folder = train_dir
target_folder = work_space_dir+"/dataset/train"
processed_folder = process_data_without_modifying_source(source_folder, target_folder)
print(processed_folder)
# 将训练集的文件夹改为调整后的train文件
train_dir = processed_folder
########平均数据数量是: 64

数据集分布调整后的分布情况

可以看出数据集已经被平衡了,这里采用的是增强数据的方式,将增强后的数据持久化到本地的方式补齐缺少的数据集

3、前后训练差别

调整前

Training on cuda
Number of epochs: 20
------------------------------0 epoch---------------------------------
Current Learning Rate: [具体学习率]
Train Loss: 4.1048, Acc: 0.1241
Valid Loss: 3.6123, Acc: 0.1932
------------------------------10 epoch---------------------------------
Current Learning Rate: [具体学习率]
Train Loss: 1.8655, Acc: 0.5386
Valid Loss: 2.5691, Acc: 0.3851
------------------------------19 epoch---------------------------------
Current Learning Rate: [具体学习率]
Train Loss: 1.6926, Acc: 0.5609
Valid Loss: 2.5491, Acc: 0.3912

调整后

Training on cuda
Number of epochs: 20
------------------------------0 epoch---------------------------------
Train Loss: 1.7026, Acc: 0.5598
Valid Loss: 2.5916, Acc: 0.4059
------------------------------10 epoch---------------------------------
Train Loss: 1.6097, Acc: 0.5818
Valid Loss: 2.6268, Acc: 0.3851
------------------------------19 epoch---------------------------------
Train Loss: 1.5454, Acc: 0.5920
Valid Loss: 2.6243, Acc: 0.3973

从训练结果来看,在数据集平衡后,验证集的准确率有小幅提升,但仍然没有显著改善。对比两次训练的情况可以看到:

原始训练结果:

  • 验证集准确率 从 0.1932 提升到 0.3912,最终稳定在 0.3912 左右。
  • 验证集损失 从 3.6123 下降到 2.5491。

平衡数据集后的训练结果:

  • 验证集准确率 在 0.4059 附近波动,最高达到了 0.4095。
  • 验证集损失 在 2.5916 附近变化,最终略微下降到 2.6243。

主要观察:

  1. 验证集准确率提升不大: 数据集平衡后,验证集的准确率有小幅波动提升,达到了约 40% 的准确率。这说明数据集的平衡对模型的泛化能力有一定帮助,但并不是根本性的问题。
  2. 验证集损失趋于平稳: 虽然平衡后的验证损失值较原始数据的稍有波动,但总体趋势差别不大,损失仍维持在类似的区间,这表明模型在处理验证集时没有显著过拟合或欠拟合。
  3. 可能的瓶颈
    • 模型结构和深度:如果模型的深度或容量不足,可能无法有效捕捉更复杂的特征,即使数据集经过了平衡,也难以进一步提升性能。
    • 特征工程:如果数据特征不够区分明显,模型可能无法有效学习,这时可以尝试进一步的数据预处理和特征提取。

下面这张图展示了平衡后训练集和验证集在训练过程中的损失(Loss)和准确率(Accuracy)随 epoch 的变化情况:

一、损失曲线分析

  1. Train Loss(训练集损失):整体呈现下降趋势,说明模型在训练过程中逐渐优化,对训练数据的拟合能力不断提高。从开始较高的值逐渐下降,表明模型在不断学习数据中的模式,降低预测误差。
  2. Valid Loss(验证集损失):波动相对较大,但也有一定的下降趋势。与训练集损失相比,验证集损失没有像训练集损失那样持续稳定下降,这可能意味着模型在训练过程中存在一定程度的过拟合风险,即模型在训练集上表现良好,但在新的、未见过的数据(验证集)上表现不稳定。

二、准确率曲线分析

  1. Train Accuracy(训练集准确率):随着 epoch 的增加逐步上升,表明模型在训练数据上的预测准确性不断提高。这是因为模型在不断调整参数以更好地拟合训练数据。
  2. Valid Accuracy(验证集准确率):整体上升趋势较为缓慢且波动。与训练集准确率相比,验证集准确率相对较低,这进一步印证了可能存在过拟合的情况。同时,波动表明模型在不同 epoch 对验证集的适应性不稳定。

三、总体结论

综合来看,模型在训练过程中虽然在训练集上表现出一定的优化趋势,但在验证集上的表现不够稳定且与训练集存在一定差距。这提示可能需要采取一些措施来防止过拟合,如增加数据增强、使用正则化技术、调整模型复杂度或优化超参数等,以提高模型的泛化能力,使其在新的数据上也能有较好的表现。

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

滚动至顶部