This project is a modification of the official RT-DETR source code, built with heavy involvement from DeepWiki; the full open-source code is available in the GitHub repository.

Modification Process

The official RT-DETR is actually not easy to modify: much of its code is written in a fairly advanced style that is not beginner-friendly. Fortunately, we can lean on large language models, and the steps below were worked out together with DeepWiki.

Conceptually, the changes are easiest to read in reverse order, starting from the train script. The goal: since the codebase is relatively closed off and not easy to modify, we change as little as possible and let a command-line flag decide whether K-fold is needed; if it is, training switches to K-fold mode. That means we only need to change two places:

1. Add the training-mode switching logic to train.py.

2. Create a training method dedicated to K-fold.

From train.py we can see that RT-DETR training calls the fit method of the DetSolver class in det_solver, so we can build a fit_kfold method on top of fit.
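
Before touching the solver, it helps to see the core mechanism fit_kfold is built on: sklearn's KFold simply splits a list of dataset indices into per-fold train/val index arrays. A minimal standalone sketch (toy index list, not project code):

# Minimal sketch of the index splitting that fit_kfold relies on (toy data).
from sklearn.model_selection import KFold

indices = list(range(10))  # pretend the dataset has 10 samples
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
    print(f"fold {fold}: train={train_idx.tolist()}, val={val_idx.tolist()}")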

Modifying det_solver.py

To add the fit_kfold method to rtdetrv2_pytorch/src/solver/det_solver.py, first add the following imports and helper class at module level:

import copy
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import Subset, DataLoader as TorchDataLoader

class _SubsetWithEpoch(Subset):
    """Subset wrapper that supports set_epoch (required by custom DataLoader)"""
    def set_epoch(self, epoch):
        if hasattr(self.dataset, 'set_epoch'):
            self.dataset.set_epoch(epoch)
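
As the docstring notes, the wrapper exists only because the project's custom DataLoader forwards set_epoch to its dataset, which a plain torch Subset does not provide. A quick sanity check of the delegation, using a hypothetical toy dataset (not part of det_solver.py):

# Toy dataset used only to verify that _SubsetWithEpoch forwards set_epoch.
from torch.utils.data import Dataset

class _ToyDataset(Dataset):
    def __init__(self):
        self.epoch = -1
    def set_epoch(self, epoch):
        self.epoch = epoch
    def __len__(self):
        return 10
    def __getitem__(self, idx):
        return idx

subset = _SubsetWithEpoch(_ToyDataset(), [0, 2, 4])
subset.set_epoch(3)                  # delegated to the wrapped dataset
assert subset.dataset.epoch == 3
assert len(subset) == 3              # Subset indexing semantics are unchanged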

Then add the following inside the DetSolver class:

def fit_kfold(self, n_splits=5, random_seed=42):
    print(f"Start {n_splits}-fold cross validation training")
    self.train()  # initialize model, optimizer, dataloaders, etc.
    args = self.cfg

    n_parameters = sum([p.numel() for p in self.model.parameters() if p.requires_grad])
    print(f'number of trainable parameters: {n_parameters}')

    # Save the initial model weights; every fold is reset to this state
    import copy
    initial_model_state = copy.deepcopy(
        dist_utils.de_parallel(self.model).state_dict()
    )

    # Get the full training dataset
    full_dataset = self.train_dataloader.dataset
    indices = list(range(len(full_dataset)))

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_seed)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
        print(f"\n=== Fold {fold + 1}/{n_splits} ===")

        # Reset model weights
        dist_utils.de_parallel(self.model).load_state_dict(initial_model_state)
        self.optimizer = self.cfg.optimizer
        self.lr_scheduler = self.cfg.lr_scheduler
        self.lr_warmup_scheduler = self.cfg.lr_warmup_scheduler
        self.last_epoch = -1

        # Create subsets (with set_epoch support)
        train_subset = _SubsetWithEpoch(full_dataset, train_idx.tolist())
        val_subset = _SubsetWithEpoch(full_dataset, val_idx.tolist())

        # Build new loaders by reusing the original dataloaders' parameters
        orig_train = self.train_dataloader
        orig_val = self.val_dataloader

        from ..data import DataLoader as RTDataLoader
        fold_train_loader = RTDataLoader(
            dataset=train_subset,
            batch_size=orig_train.batch_size,
            shuffle=True,
            num_workers=orig_train.num_workers,
            drop_last=orig_train.drop_last,
            collate_fn=orig_train.collate_fn,
            pin_memory=orig_train.pin_memory,
        )
        fold_val_loader = RTDataLoader(
            dataset=val_subset,
            batch_size=orig_val.batch_size,
            shuffle=False,
            num_workers=orig_val.num_workers,
            drop_last=orig_val.drop_last,
            collate_fn=orig_val.collate_fn,
            pin_memory=orig_val.pin_memory,
        )

        fold_train_loader = dist_utils.warp_loader(fold_train_loader, shuffle=True)
        fold_val_loader = dist_utils.warp_loader(fold_val_loader, shuffle=False)

        # Train the current fold
        best_stat = {'epoch': -1}
        for epoch in range(args.epoches):
            fold_train_loader.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                fold_train_loader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                self.model, self.criterion, fold_train_loader,
                self.optimizer, self.device, epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer
            )

            if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                self.lr_scheduler.step()

            self.last_epoch += 1

            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module, self.criterion, self.postprocessor,
                fold_val_loader, self.evaluator, self.device
            )

            for k in test_stats:
                if k in best_stat:
                    best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = test_stats[k][0]

            print(f'Fold {fold+1} Epoch {epoch} best_stat: {best_stat}')

        fold_results.append(best_stat)

        # Save the best checkpoint of this fold
        if self.output_dir:
            dist_utils.save_on_master(
                self.state_dict(),
                self.output_dir / f'fold_{fold+1}_best.pth'
            )

    # Summarize the results
    print("\n=== K-Fold Results ===")
    for i, r in enumerate(fold_results):
        print(f"Fold {i+1}: {r}")
    if self.output_dir and dist_utils.is_main_process():
        with (self.output_dir / "kfold_results.txt").open("w") as f:
            for i, r in enumerate(fold_results):
                f.write(f"Fold {i+1}: {r}\n")
    return fold_results
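
If you want a single headline number across folds, a small post-processing helper (hypothetical, not part of the repo) can average the per-fold best metrics returned by fit_kfold. This assumes each best_stat holds a scalar AP under 'coco_eval_bbox', which is what the bookkeeping above produces for COCO-style evaluation:

import numpy as np

def summarize_folds(fold_results, key='coco_eval_bbox'):
    """Mean and std of one best-metric entry across folds (skips folds missing the key)."""
    values = [r[key] for r in fold_results if key in r]
    return float(np.mean(values)), float(np.std(values))

# Example: mean_ap, std_ap = summarize_folds(solver.fit_kfold(n_splits=5))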

Modifying train.py

In rtdetrv2_pytorch/tools/train.py:

# Modify in the main function
if args.test_only:
    solver.val()
elif args.kfold:
    solver.fit_kfold(n_splits=args.kfold_splits, random_seed=args.seed or 42)
else:
    solver.fit()

# Add in the argparse section
parser.add_argument('--kfold', action='store_true', default=False)
parser.add_argument('--kfold-splits', type=int, default=5)

Usage

Adding the dataset

For the dataset, simply point both the training and validation dataset paths in the existing config file to the new dataset. The new dataset does not need to be split into training and validation sets.
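
If your data currently lives in separate COCO train/val annotation files, you first need one combined file for K-fold to split over. A rough merging sketch (hypothetical helper, not part of the repo; assumes all files share the same categories):

import json

def merge_coco(ann_files, out_file):
    """Concatenate COCO annotation files, remapping image/annotation ids to avoid collisions."""
    merged = {'images': [], 'annotations': [], 'categories': None}
    img_offset, ann_offset = 0, 0
    for path in ann_files:
        with open(path) as f:
            coco = json.load(f)
        if merged['categories'] is None:
            merged['categories'] = coco['categories']
        id_map = {}
        for img in coco['images']:
            id_map[img['id']] = img['id'] + img_offset
            merged['images'].append({**img, 'id': id_map[img['id']]})
        for ann in coco['annotations']:
            merged['annotations'].append({**ann,
                                          'id': ann['id'] + ann_offset,
                                          'image_id': id_map[ann['image_id']]})
        img_offset = max((i['id'] for i in merged['images']), default=-1) + 1
        ann_offset = max((a['id'] for a in merged['annotations']), default=-1) + 1
    with open(out_file, 'w') as f:
        json.dump(merged, f)

# Example: merge_coco(['instances_train.json', 'instances_val.json'], 'instances_all.json')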

Starting training

Launch from the terminal:

# Default cross-validation (5 folds)
python tools/train.py -c configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml --kfold

# Specify the number of folds and the random seed
python tools/train.py -c configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml --kfold --kfold-splits 5 --seed 42

# Load pretrained weights
python tools/train.py -c configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml -r path/to/pretrained.pth --kfold