搜 索

如何训练一个YOLO目标检测模型

  • 69阅读
  • 2025年12月31日
  • 1评论
首页 / 作品/创作 / 正文

现在这个网站已经有了成品,请参考以下链接并试用
深度检测

前言:一个生物多样性的梦

作为一个 AI 架构师,我接到了一个很有意义的项目:开发一套生物多样性识别系统

目标很明确:

  • 识别各种鸟类(麻雀、喜鹊、白鹭、翠鸟...)
  • 识别各种鱼类(金鱼、鲤鱼、热带鱼...)
  • 识别各种小动物(松鼠、刺猬、兔子、蜥蜴...)

应用场景包括:

  • 野外红外相机的自动物种识别
  • 水族馆/动物园的智能监控
  • 自然保护区的生态监测
  • 公民科学项目的物种记录

技术选型自然是 YOLO——目标检测领域的当红炸子鸡,速度快、精度高、部署方便。

然而,理想很丰满,现实很骨感。


第一章:MacBook 的绝望——为什么我放弃了本地训练

1.1 最初的尝试

作为一个 MacBook Pro 用户(M4 Pro,18GB 内存),我天真地以为:

"Apple Silicon 不是很强吗?MPS 加速不是很快吗?训练个小模型应该没问题吧?"

于是我开始了尝试:

from ultralytics import YOLO

# 加载预训练模型
model = YOLO('yolov8n.pt')

# 开始训练
model.train(
    data='biodiversity.yaml',
    epochs=100,
    imgsz=640,
    batch=16,
    device='mps'  # Apple Silicon GPU
)

然后...

RuntimeError: MPS backend out of memory

好吧,把 batch 改成 8:

# 能跑了,但是...
Epoch 1/100: 100%|██████████| 500/500 [15:32<00:00, 1.86s/it]

一个 epoch 要 15 分钟,100 个 epoch 就是 25 小时!

而且这还只是最小的 YOLOv8n,如果换成 YOLOv8m 或 YOLOv8l...

1.2 MacBook 训练的残酷现实

配置YOLOv8nYOLOv8sYOLOv8m
MacBook M4 Pro 18GB勉强能跑OOM别想了
每 epoch 时间~15min--
100 epochs 总时间~25h--
风扇噪音起飞--
能否日常使用不能--

问题总结:

  1. 内存/显存不足:18GB 统一内存要同时跑系统和训练,捉襟见肘
  2. MPS 后端不成熟:经常遇到兼容性问题,某些操作不支持
  3. 速度太慢:和 NVIDIA GPU 比,差距是数量级的
  4. 影响日常使用:训练时电脑基本不能干别的事

1.3 痛定思痛

经过一番挣扎,我接受了现实:

深度学习训练,还得是 NVIDIA GPU。

MacBook 适合做推理、做轻量级实验,但大规模训练?还是交给专业的来吧。

于是,我把目光投向了云 GPU 服务。


第二章:AutoDL 初体验——云上炼丹的正确姿势

2.1 为什么选择 AutoDL

对比了几个云 GPU 服务:

服务商优点缺点
AutoDL便宜、简单、国内访问快高峰期可能没卡
AWS/GCP稳定、配置灵活贵、需要信用卡、配置复杂
阿里云/腾讯云大厂背书相对较贵
Colab免费限时、断连、速度不稳定

最终选择 AutoDL,原因:

  1. 价格实惠:RTX 4090 大约 2-3 元/小时
  2. 操作简单:网页端直接操作,支持 JupyterLab 和 SSH
  3. 预装环境:有现成的 PyTorch 镜像,开箱即用
  4. 数据传输快:国内服务器,上传下载速度有保障

2.2 创建实例

  1. 注册 AutoDL 账号(https://www.autodl.com
  2. 充值(支持支付宝,先充个 50-100 块试试)
  3. 创建实例:

    • GPU:RTX 4090(24GB 显存,性价比之王)
    • 镜像:PyTorch 2.0 + Python 3.10 + CUDA 11.8
    • 硬盘:系统盘默认,数据盘按需(我选了 50GB)

2.3 环境配置

SSH 连接到服务器后,配置训练环境:

# 1. 确认 GPU 状态
nvidia-smi

# 输出应该看到 RTX 4090,显存 24GB
# +-----------------------------------------------------------------------------+
# | NVIDIA-SMI 525.xx.xx    Driver Version: 525.xx.xx    CUDA Version: 12.0     |
# |-------------------------------+----------------------+----------------------+
# | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
# | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
# |===============================+======================+======================|
# |   0  NVIDIA GeForce ...  Off  | 00000000:00:00.0 Off |                  N/A |
# |  0%   30C    P8    10W / 450W |      0MiB / 24576MiB |      0%      Default |
# +-------------------------------+----------------------+----------------------+

# 2. 安装 ultralytics(YOLOv8 官方库)
pip install ultralytics -i https://pypi.tuna.tsinghua.edu.cn/simple

# 3. 安装其他依赖
pip install albumentations opencv-python-headless -i https://pypi.tuna.tsinghua.edu.cn/simple

# 4. 验证安装
python -c "from ultralytics import YOLO; print('YOLOv8 ready!')"

第三章:数据集准备——生物多样性数据的艺术

3.1 数据来源

训练目标检测模型,最重要的是数据。我使用了多个数据源:

数据源内容数量
iNaturalist开源自然观察数据集~5000 张
Open ImagesGoogle 开源数据集(筛选动物类别)~3000 张
Kaggle 数据集鸟类、鱼类专项数据集~2000 张
自行标注网络图片 + LabelImg 标注~1000 张

最终数据集构成:

biodiversity_dataset/
├── images/
│   ├── train/          # 训练集 ~8800 张
│   ├── val/            # 验证集 ~1100 张
│   └── test/           # 测试集 ~1100 张
├── labels/
│   ├── train/          # 对应标注文件
│   ├── val/
│   └── test/
└── biodiversity.yaml   # 数据集配置文件

3.2 类别定义

根据项目需求,我定义了 20 个类别

# biodiversity.yaml
path: /root/autodl-tmp/biodiversity_dataset
train: images/train
val: images/val
test: images/test

# 类别定义
names:
  # 鸟类 (0-7)
  0: sparrow          # 麻雀
  1: magpie           # 喜鹊
  2: egret            # 白鹭
  3: kingfisher       # 翠鸟
  4: crow             # 乌鸦
  5: pigeon           # 鸽子
  6: swallow          # 燕子
  7: parrot           # 鹦鹉
  
  # 鱼类 (8-13)
  8: goldfish         # 金鱼
  9: koi              # 锦鲤
  10: tropical_fish   # 热带鱼
  11: catfish         # 鲶鱼
  12: bass            # 鲈鱼
  13: salmon          # 三文鱼
  
  # 小动物 (14-19)
  14: squirrel        # 松鼠
  15: rabbit          # 兔子
  16: hedgehog        # 刺猬
  17: lizard          # 蜥蜴
  18: frog            # 青蛙
  19: turtle          # 乌龟

nc: 20  # 类别数量

3.3 数据标注格式

YOLO 使用的是归一化的 xywh 格式:

# labels/train/img_001.txt
# 格式:class_id center_x center_y width height
# 所有值都是相对于图片尺寸的归一化值 (0-1)

0 0.456 0.523 0.124 0.089
14 0.721 0.634 0.098 0.156

3.4 数据预处理脚本

# scripts/prepare_dataset.py
"""
数据集预处理脚本
1. 统一图片格式
2. 检查标注文件
3. 划分训练/验证/测试集
4. 数据增强(可选)
"""

import os
import shutil
import random
from pathlib import Path
from PIL import Image
import yaml

def check_image(img_path):
    """检查图片是否有效"""
    try:
        img = Image.open(img_path)
        img.verify()
        return True
    except:
        return False

def check_label(label_path, num_classes=20):
    """检查标注文件格式是否正确"""
    if not os.path.exists(label_path):
        return False
    
    with open(label_path, 'r') as f:
        lines = f.readlines()
    
    for line in lines:
        parts = line.strip().split()
        if len(parts) != 5:
            return False
        
        class_id = int(parts[0])
        if class_id < 0 or class_id >= num_classes:
            return False
        
        # 检查坐标是否在 0-1 范围内
        for val in parts[1:]:
            v = float(val)
            if v < 0 or v > 1:
                return False
    
    return True

def split_dataset(source_dir, output_dir, train_ratio=0.8, val_ratio=0.1):
    """划分数据集"""
    images_dir = Path(source_dir) / 'images'
    labels_dir = Path(source_dir) / 'labels'
    
    # 获取所有有效的图片-标注对
    valid_pairs = []
    for img_path in images_dir.glob('*.*'):
        if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']:
            label_path = labels_dir / (img_path.stem + '.txt')
            if check_image(img_path) and check_label(label_path):
                valid_pairs.append((img_path, label_path))
    
    print(f"有效数据对: {len(valid_pairs)}")
    
    # 随机打乱
    random.shuffle(valid_pairs)
    
    # 计算划分点
    n = len(valid_pairs)
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))
    
    splits = {
        'train': valid_pairs[:train_end],
        'val': valid_pairs[train_end:val_end],
        'test': valid_pairs[val_end:]
    }
    
    # 创建目录并复制文件
    output_path = Path(output_dir)
    for split_name, pairs in splits.items():
        img_dir = output_path / 'images' / split_name
        lbl_dir = output_path / 'labels' / split_name
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
        
        for img_path, label_path in pairs:
            shutil.copy(img_path, img_dir / img_path.name)
            shutil.copy(label_path, lbl_dir / label_path.name)
        
        print(f"{split_name}: {len(pairs)} 张")
    
    return splits

def analyze_dataset(dataset_dir, yaml_path):
    """分析数据集统计信息"""
    with open(yaml_path, 'r') as f:
        config = yaml.safe_load(f)
    
    class_names = config['names']
    class_counts = {name: 0 for name in class_names.values()}
    
    labels_dir = Path(dataset_dir) / 'labels' / 'train'
    for label_file in labels_dir.glob('*.txt'):
        with open(label_file, 'r') as f:
            for line in f:
                class_id = int(line.strip().split()[0])
                class_name = class_names[class_id]
                class_counts[class_name] += 1
    
    print("\n=== 数据集类别分布 ===")
    for name, count in sorted(class_counts.items(), key=lambda x: -x[1]):
        bar = '█' * (count // 100)
        print(f"{name:15s}: {count:5d} {bar}")
    
    return class_counts

if __name__ == '__main__':
    # 划分数据集
    split_dataset(
        source_dir='/root/autodl-tmp/raw_data',
        output_dir='/root/autodl-tmp/biodiversity_dataset',
        train_ratio=0.8,
        val_ratio=0.1
    )
    
    # 分析数据集
    analyze_dataset(
        dataset_dir='/root/autodl-tmp/biodiversity_dataset',
        yaml_path='/root/autodl-tmp/biodiversity_dataset/biodiversity.yaml'
    )

3.5 数据增强策略

YOLO 内置了丰富的数据增强,但我们也可以自定义:

# 在训练配置中设置数据增强参数
augmentation_config = {
    'hsv_h': 0.015,      # 色调变化
    'hsv_s': 0.7,        # 饱和度变化
    'hsv_v': 0.4,        # 亮度变化
    'degrees': 10,       # 旋转角度
    'translate': 0.1,    # 平移
    'scale': 0.5,        # 缩放
    'shear': 5,          # 剪切
    'perspective': 0.0,  # 透视变换
    'flipud': 0.5,       # 上下翻转概率
    'fliplr': 0.5,       # 左右翻转概率
    'mosaic': 1.0,       # Mosaic 增强概率
    'mixup': 0.1,        # MixUp 增强概率
    'copy_paste': 0.1,   # Copy-Paste 增强概率
}

第四章:模型训练——4090 的狂喜

4.1 训练脚本

# train.py
"""
生物多样性目标检测模型训练脚本
使用 YOLOv8 在 RTX 4090 上训练
"""

import os
import torch
from ultralytics import YOLO
from datetime import datetime

def train():
    # ==================== 配置 ====================
    # 数据集配置
    DATA_YAML = '/root/autodl-tmp/biodiversity_dataset/biodiversity.yaml'
    
    # 模型选择(按需选择)
    # yolov8n.pt - Nano (最快,精度较低)
    # yolov8s.pt - Small (平衡)
    # yolov8m.pt - Medium (推荐)
    # yolov8l.pt - Large (精度高,较慢)
    # yolov8x.pt - XLarge (最高精度)
    MODEL = 'yolov8m.pt'
    
    # 训练超参数
    EPOCHS = 150
    BATCH_SIZE = 32          # 4090 24GB 显存可以用 32
    IMG_SIZE = 640
    WORKERS = 8
    
    # 优化器配置
    OPTIMIZER = 'AdamW'
    LR0 = 0.001              # 初始学习率
    LRF = 0.01               # 最终学习率因子
    MOMENTUM = 0.937
    WEIGHT_DECAY = 0.0005
    
    # 其他配置
    PROJECT = 'runs/biodiversity'
    NAME = f'yolov8m_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    
    # ==================== 环境检查 ====================
    print("=" * 50)
    print("环境检查")
    print("=" * 50)
    print(f"PyTorch 版本: {torch.__version__}")
    print(f"CUDA 可用: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA 版本: {torch.version.cuda}")
        print(f"GPU 设备: {torch.cuda.get_device_name(0)}")
        print(f"GPU 显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print("=" * 50)
    
    # ==================== 加载模型 ====================
    print(f"\n加载预训练模型: {MODEL}")
    model = YOLO(MODEL)
    
    # ==================== 开始训练 ====================
    print(f"\n开始训练...")
    print(f"数据集: {DATA_YAML}")
    print(f"Epochs: {EPOCHS}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Image Size: {IMG_SIZE}")
    
    results = model.train(
        # 数据配置
        data=DATA_YAML,
        epochs=EPOCHS,
        batch=BATCH_SIZE,
        imgsz=IMG_SIZE,
        
        # 设备配置
        device=0,              # 使用第一块 GPU
        workers=WORKERS,
        
        # 优化器配置
        optimizer=OPTIMIZER,
        lr0=LR0,
        lrf=LRF,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        
        # 学习率调度
        cos_lr=True,           # 余弦退火学习率
        warmup_epochs=3,       # 预热 epoch 数
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,
        
        # 数据增强
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=10,
        translate=0.1,
        scale=0.5,
        shear=5,
        flipud=0.5,
        fliplr=0.5,
        mosaic=1.0,
        mixup=0.1,
        
        # 损失函数权重
        box=7.5,               # 边界框损失权重
        cls=0.5,               # 分类损失权重
        dfl=1.5,               # DFL 损失权重
        
        # 正则化
        dropout=0.0,
        
        # 保存配置
        project=PROJECT,
        name=NAME,
        save=True,
        save_period=10,        # 每 10 个 epoch 保存一次
        
        # 日志配置
        verbose=True,
        plots=True,            # 生成训练曲线图
        
        # 早停配置
        patience=30,           # 30 个 epoch 无改善则停止
        
        # 其他
        seed=42,
        deterministic=True,
        amp=True,              # 混合精度训练,加速!
        resume=False,          # 是否从断点继续
    )
    
    print("\n" + "=" * 50)
    print("训练完成!")
    print("=" * 50)
    print(f"最佳模型保存在: {PROJECT}/{NAME}/weights/best.pt")
    
    return results

if __name__ == '__main__':
    train()

4.2 训练过程监控

# monitor.py
"""
训练过程监控脚本
实时查看训练状态
"""

import os
import time
import subprocess

def monitor_gpu():
    """监控 GPU 使用情况"""
    while True:
        os.system('clear')
        print("=" * 60)
        print("GPU 监控 (按 Ctrl+C 退出)")
        print("=" * 60)
        
        # GPU 状态
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        print(result.stdout)
        
        # 最新日志
        log_files = sorted(
            [f for f in os.listdir('runs/biodiversity') if os.path.isdir(f'runs/biodiversity/{f}')],
            reverse=True
        )
        
        if log_files:
            latest_run = log_files[0]
            results_file = f'runs/biodiversity/{latest_run}/results.csv'
            if os.path.exists(results_file):
                print("\n最新训练结果:")
                os.system(f'tail -5 {results_file}')
        
        time.sleep(5)

if __name__ == '__main__':
    try:
        monitor_gpu()
    except KeyboardInterrupt:
        print("\n监控已停止")

4.3 4090 vs MacBook 训练对比

实测数据:

指标MacBook M4 ProRTX 4090
模型YOLOv8n (被迫)YOLOv8m
Batch Size8 (最大)32
每 Epoch 时间~15 min~45 sec
100 Epochs 总时间~25 h~1.25 h
显存占用OOM~18 GB
最终 mAP~0.45~0.72

结论:4090 的训练速度是 MacBook 的 20 倍以上,而且可以训练更大的模型!


第五章:模型评估与优化

5.1 验证脚本

# evaluate.py
"""
模型评估脚本
"""

from ultralytics import YOLO
import matplotlib.pyplot as plt
from pathlib import Path

def evaluate_model(model_path, data_yaml, save_dir='evaluation_results'):
    """
    全面评估模型性能
    """
    print(f"加载模型: {model_path}")
    model = YOLO(model_path)
    
    # 创建保存目录
    save_path = Path(save_dir)
    save_path.mkdir(exist_ok=True)
    
    # 在验证集上评估
    print("\n在验证集上评估...")
    val_results = model.val(
        data=data_yaml,
        split='val',
        batch=32,
        imgsz=640,
        conf=0.25,
        iou=0.6,
        plots=True,
        save_json=True,
    )
    
    # 打印关键指标
    print("\n" + "=" * 50)
    print("评估结果")
    print("=" * 50)
    print(f"mAP50:      {val_results.box.map50:.4f}")
    print(f"mAP50-95:   {val_results.box.map:.4f}")
    print(f"Precision:  {val_results.box.mp:.4f}")
    print(f"Recall:     {val_results.box.mr:.4f}")
    
    # 各类别 AP
    print("\n各类别 AP:")
    class_names = model.names
    for i, ap in enumerate(val_results.box.ap50):
        print(f"  {class_names[i]:15s}: {ap:.4f}")
    
    return val_results

def analyze_errors(model_path, test_images_dir, save_dir='error_analysis'):
    """
    分析预测错误
    """
    model = YOLO(model_path)
    save_path = Path(save_dir)
    save_path.mkdir(exist_ok=True)
    
    # 预测测试图片
    results = model.predict(
        source=test_images_dir,
        conf=0.25,
        save=True,
        save_txt=True,
        project=save_dir,
        name='predictions'
    )
    
    print(f"预测结果已保存到: {save_dir}/predictions")
    
    return results

def export_metrics_report(val_results, output_path='metrics_report.md'):
    """
    导出详细评估报告
    """
    report = f"""# 模型评估报告

## 总体指标

| 指标 | 值 |
|-----|-----|
| mAP@0.5 | {val_results.box.map50:.4f} |
| mAP@0.5:0.95 | {val_results.box.map:.4f} |
| Precision | {val_results.box.mp:.4f} |
| Recall | {val_results.box.mr:.4f} |

## 各类别性能

| 类别 | AP@0.5 | AP@0.5:0.95 |
|-----|-------|------------|
"""
    
    for i, (ap50, ap) in enumerate(zip(val_results.box.ap50, val_results.box.ap)):
        class_name = val_results.names[i]
        report += f"| {class_name} | {ap50:.4f} | {ap:.4f} |\n"
    
    with open(output_path, 'w') as f:
        f.write(report)
    
    print(f"报告已保存到: {output_path}")

if __name__ == '__main__':
    MODEL_PATH = 'runs/biodiversity/yolov8m_best/weights/best.pt'
    DATA_YAML = '/root/autodl-tmp/biodiversity_dataset/biodiversity.yaml'
    
    # 评估模型
    results = evaluate_model(MODEL_PATH, DATA_YAML)
    
    # 导出报告
    export_metrics_report(results)

5.2 我的训练结果

经过 150 个 epoch 的训练,最终模型表现:

=== 总体指标 ===
mAP@0.5:      0.724
mAP@0.5:0.95: 0.518
Precision:    0.731
Recall:       0.689

=== 各类别 AP@0.5 ===
鸟类:
  sparrow:      0.812
  magpie:       0.798
  egret:        0.856
  kingfisher:   0.723
  crow:         0.834
  pigeon:       0.891
  swallow:      0.654
  parrot:       0.789

鱼类:
  goldfish:     0.756
  koi:          0.812
  tropical_fish: 0.634
  catfish:      0.589
  bass:         0.623
  salmon:       0.712

小动物:
  squirrel:     0.834
  rabbit:       0.867
  hedgehog:     0.723
  lizard:       0.598
  frog:         0.645
  turtle:       0.789

5.3 性能优化技巧

训练过程中我尝试了多种优化方法:

1. 学习率调优

# 使用学习率查找器
from ultralytics import YOLO

model = YOLO('yolov8m.pt')

# 学习率范围测试
# 可以通过观察 loss 曲线找到最佳学习率
# 一般选择 loss 下降最快的点的 1/10

2. 数据平衡

针对类别不平衡问题:

# 方法1:过采样少数类
# 方法2:欠采样多数类
# 方法3:使用 class weights

# 在 YOLO 中可以通过调整 cls 损失权重来缓解
# 或者使用 Focal Loss

3. 模型融合

# 训练多个模型,推理时融合结果
from ultralytics import YOLO

model1 = YOLO('runs/exp1/weights/best.pt')
model2 = YOLO('runs/exp2/weights/best.pt')

# 简单方法:取多个模型的平均置信度
# 高级方法:使用 WBF (Weighted Boxes Fusion)

4. TTA(测试时增强)

# 推理时使用数据增强
results = model.predict(
    source='test_image.jpg',
    augment=True  # 开启 TTA
)

第六章:模型部署

6.1 模型导出

# export.py
"""
模型导出脚本
支持多种格式
"""

from ultralytics import YOLO

def export_model(model_path):
    model = YOLO(model_path)
    
    # 导出 ONNX(通用格式,推荐)
    model.export(
        format='onnx',
        imgsz=640,
        simplify=True,
        opset=12,
        dynamic=False,  # 固定输入尺寸
    )
    print("ONNX 模型已导出")
    
    # 导出 TensorRT(NVIDIA GPU 推理最快)
    model.export(
        format='engine',
        imgsz=640,
        device=0,
        half=True,  # FP16 加速
    )
    print("TensorRT 模型已导出")
    
    # 导出 CoreML(iOS/macOS 部署)
    model.export(
        format='coreml',
        imgsz=640,
        nms=True,
    )
    print("CoreML 模型已导出")
    
    # 导出 TFLite(移动端/边缘设备)
    model.export(
        format='tflite',
        imgsz=640,
        int8=True,  # INT8 量化
    )
    print("TFLite 模型已导出")

if __name__ == '__main__':
    export_model('runs/biodiversity/yolov8m_best/weights/best.pt')

6.2 推理脚本

# inference.py
"""
推理脚本
支持图片、视频、摄像头
"""

import cv2
from ultralytics import YOLO
from pathlib import Path

class BiodiversityDetector:
    def __init__(self, model_path, conf_threshold=0.25):
        self.model = YOLO(model_path)
        self.conf_threshold = conf_threshold
        
        # 类别颜色(BGR 格式)
        self.colors = {
            'bird': (0, 255, 0),      # 绿色
            'fish': (255, 128, 0),    # 蓝色
            'animal': (0, 128, 255),  # 橙色
        }
        
        # 类别分组
        self.bird_classes = ['sparrow', 'magpie', 'egret', 'kingfisher', 
                             'crow', 'pigeon', 'swallow', 'parrot']
        self.fish_classes = ['goldfish', 'koi', 'tropical_fish', 
                             'catfish', 'bass', 'salmon']
        self.animal_classes = ['squirrel', 'rabbit', 'hedgehog', 
                               'lizard', 'frog', 'turtle']
    
    def get_color(self, class_name):
        """根据类别返回颜色"""
        if class_name in self.bird_classes:
            return self.colors['bird']
        elif class_name in self.fish_classes:
            return self.colors['fish']
        else:
            return self.colors['animal']
    
    def detect_image(self, image_path, save_path=None):
        """检测单张图片"""
        results = self.model.predict(
            source=image_path,
            conf=self.conf_threshold,
            verbose=False
        )[0]
        
        # 绘制结果
        img = cv2.imread(image_path)
        annotated = self.draw_results(img, results)
        
        if save_path:
            cv2.imwrite(save_path, annotated)
            print(f"结果已保存到: {save_path}")
        
        return results, annotated
    
    def detect_video(self, video_path, output_path=None):
        """检测视频"""
        cap = cv2.VideoCapture(video_path)
        
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            
            # 检测
            results = self.model.predict(
                source=frame,
                conf=self.conf_threshold,
                verbose=False
            )[0]
            
            # 绘制
            annotated = self.draw_results(frame, results)
            
            if output_path:
                out.write(annotated)
            
            frame_count += 1
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count} 帧")
        
        cap.release()
        if output_path:
            out.release()
            print(f"视频已保存到: {output_path}")
    
    def detect_camera(self, camera_id=0):
        """实时摄像头检测"""
        cap = cv2.VideoCapture(camera_id)
        
        print("按 'q' 退出")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            results = self.model.predict(
                source=frame,
                conf=self.conf_threshold,
                verbose=False
            )[0]
            
            annotated = self.draw_results(frame, results)
            
            cv2.imshow('Biodiversity Detection', annotated)
            
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        
        cap.release()
        cv2.destroyAllWindows()
    
    def draw_results(self, img, results):
        """绘制检测结果"""
        annotated = img.copy()
        
        for box in results.boxes:
            # 获取边界框
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])
            cls_name = results.names[cls_id]
            
            # 获取颜色
            color = self.get_color(cls_name)
            
            # 绘制边界框
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
            
            # 绘制标签背景
            label = f'{cls_name} {conf:.2f}'
            (label_w, label_h), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
            )
            cv2.rectangle(
                annotated, 
                (x1, y1 - label_h - 10), 
                (x1 + label_w, y1), 
                color, -1
            )
            
            # 绘制标签文字
            cv2.putText(
                annotated, label, (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1
            )
        
        # 添加统计信息
        stats = self.get_detection_stats(results)
        y_offset = 30
        for category, count in stats.items():
            text = f'{category}: {count}'
            cv2.putText(
                annotated, text, (10, y_offset),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2
            )
            y_offset += 30
        
        return annotated
    
    def get_detection_stats(self, results):
        """统计检测结果"""
        stats = {'Birds': 0, 'Fish': 0, 'Animals': 0}
        
        for box in results.boxes:
            cls_id = int(box.cls[0])
            cls_name = results.names[cls_id]
            
            if cls_name in self.bird_classes:
                stats['Birds'] += 1
            elif cls_name in self.fish_classes:
                stats['Fish'] += 1
            else:
                stats['Animals'] += 1
        
        return stats

# 使用示例
if __name__ == '__main__':
    detector = BiodiversityDetector(
        model_path='runs/biodiversity/yolov8m_best/weights/best.pt',
        conf_threshold=0.3
    )
    
    # 检测图片
    detector.detect_image('test_images/bird.jpg', 'output/bird_result.jpg')
    
    # 检测视频
    # detector.detect_video('test_video.mp4', 'output/result.mp4')
    
    # 实时检测
    # detector.detect_camera()

6.3 API 服务

# api_server.py
"""
FastAPI 推理服务
"""

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import cv2
import numpy as np
from ultralytics import YOLO
import io
from PIL import Image

app = FastAPI(title="Biodiversity Detection API")

# 加载模型(全局单例)
model = YOLO('runs/biodiversity/yolov8m_best/weights/best.pt')

@app.post("/detect")
async def detect(file: UploadFile = File(...), confidence: float = 0.25):
    """
    检测上传图片中的生物
    """
    # 读取图片
    contents = await file.read()
    image = Image.open(io.BytesIO(contents))
    
    # 推理
    results = model.predict(source=image, conf=confidence, verbose=False)[0]
    
    # 构建响应
    detections = []
    for box in results.boxes:
        detections.append({
            'class': results.names[int(box.cls[0])],
            'confidence': float(box.conf[0]),
            'bbox': {
                'x1': int(box.xyxy[0][0]),
                'y1': int(box.xyxy[0][1]),
                'x2': int(box.xyxy[0][2]),
                'y2': int(box.xyxy[0][3]),
            }
        })
    
    return JSONResponse({
        'success': True,
        'count': len(detections),
        'detections': detections
    })

@app.get("/health")
async def health():
    return {"status": "healthy"}

# 启动命令: uvicorn api_server:app --host 0.0.0.0 --port 8000

第七章:经验总结与踩坑记录

7.1 踩过的坑

坑 1:数据集路径问题

# 错误:使用相对路径
data: dataset/biodiversity.yaml

# 正确:使用绝对路径
data: /root/autodl-tmp/biodiversity_dataset/biodiversity.yaml

坑 2:显存不足

# 现象:CUDA out of memory

# 解决方案:
# 1. 减小 batch_size
# 2. 减小图片尺寸
# 3. 使用混合精度训练 (amp=True)
# 4. 使用梯度累积
# 5. 清理显存
import torch
torch.cuda.empty_cache()

坑 3:训练不收敛

# 可能原因:
# 1. 学习率太大或太小
# 2. 数据标注错误
# 3. 数据增强太激进

# 排查方法:
# 1. 检查几张训练图片的标注是否正确
# 2. 先用小数据集快速迭代验证
# 3. 关闭数据增强看是否能收敛

坑 4:AutoDL 连接断开

# 使用 tmux 或 screen 保持会话
tmux new -s train
python train.py

# 断开后重新连接
tmux attach -t train

# 或者使用 nohup
nohup python train.py > train.log 2>&1 &

7.2 性能优化清单

✅ 使用预训练权重(迁移学习)
✅ 开启混合精度训练(AMP)
✅ 使用余弦退火学习率
✅ 适当的数据增强
✅ 使用 EMA(指数移动平均)
✅ 早停防止过拟合
✅ 多尺度训练
✅ 标签平滑

7.3 最终项目结构

biodiversity_detection/
├── data/
│   ├── raw/                    # 原始数据
│   └── biodiversity_dataset/   # 处理后的数据集
│       ├── images/
│       ├── labels/
│       └── biodiversity.yaml
├── scripts/
│   ├── prepare_dataset.py      # 数据准备
│   ├── train.py                # 训练脚本
│   ├── evaluate.py             # 评估脚本
│   ├── export.py               # 模型导出
│   └── inference.py            # 推理脚本
├── runs/                       # 训练输出
│   └── biodiversity/
│       └── yolov8m_best/
│           ├── weights/
│           │   ├── best.pt
│           │   └── last.pt
│           └── results.csv
├── deploy/
│   ├── api_server.py           # API 服务
│   ├── requirements.txt
│   └── Dockerfile
└── README.md

结语:从绝望到狂喜

回顾这个项目,我的心路历程是这样的:

  1. 天真期:"MacBook 应该够用吧?"
  2. 绝望期:"为什么这么慢?为什么 OOM?"
  3. 觉醒期:"原来深度学习真的需要好显卡..."
  4. 狂喜期:"4090 真香!一个 epoch 不到一分钟!"
  5. 满足期:"模型终于能认出我家小区的麻雀了!"

一些关键数字:

项目数值
总训练时间~4 小时
AutoDL 费用~10 元
最终 mAP@0.50.724
支持类别数20
模型大小~50 MB
推理速度~15ms/张 (4090)

如果你也有类似的项目需求,我的建议是:

  1. 不要在消费级设备上死磕训练——租云 GPU,性价比更高
  2. 数据质量比数量更重要——宁可少而精,不要多而乱
  3. 从小模型开始——先跑通流程,再追求性能
  4. 善用预训练权重——站在巨人的肩膀上
  5. 多看官方文档——Ultralytics 的文档写得很好

最后,希望这篇文章能帮助到同样想训练目标检测模型的你。

让 AI 认识这个美丽星球上的每一个生命,这件事本身就很酷。


项目代码已开源:GitHub - detech
评论区

偷个代码先

avatar