def fit_kfold(self, n_splits=5, random_seed=42):
    """Train with k-fold cross validation over the training dataset.

    Each fold restarts from the same initial weights and reports its best
    validation statistics; per-fold checkpoints and a summary file are
    written to the output directory.
    """
    print(f"Start {n_splits}-fold cross validation training")
    self.train()
    args = self.cfg

    n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
    print(f'number of trainable parameters: {n_parameters}')

    # Snapshot the initial weights so every fold starts from the same state.
    import copy
    initial_model_state = copy.deepcopy(
        dist_utils.de_parallel(self.model).state_dict()
    )

    # Split the training dataset indices into k folds.
    # KFold comes from scikit-learn (assumed available; it may already be
    # imported at module level).
    from sklearn.model_selection import KFold
    full_dataset = self.train_dataloader.dataset
    indices = list(range(len(full_dataset)))
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_seed)

    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(indices)):
        print(f"\n=== Fold {fold + 1}/{n_splits} ===")

        # Reset model weights and re-fetch optimizer/schedulers from the
        # config so each fold trains independently.
        dist_utils.de_parallel(self.model).load_state_dict(initial_model_state)
        self.optimizer = self.cfg.optimizer
        self.lr_scheduler = self.cfg.lr_scheduler
        self.lr_warmup_scheduler = self.cfg.lr_warmup_scheduler
        self.last_epoch = -1

        # Build per-fold train/val subsets and loaders that mirror the
        # settings of the original data loaders.
        train_subset = _SubsetWithEpoch(full_dataset, train_idx.tolist())
        val_subset = _SubsetWithEpoch(full_dataset, val_idx.tolist())

        orig_train = self.train_dataloader
        orig_val = self.val_dataloader

        from ..data import DataLoader as RTDataLoader
        fold_train_loader = RTDataLoader(
            dataset=train_subset,
            batch_size=orig_train.batch_size,
            shuffle=True,
            num_workers=orig_train.num_workers,
            drop_last=orig_train.drop_last,
            collate_fn=orig_train.collate_fn,
            pin_memory=orig_train.pin_memory,
        )
        fold_val_loader = RTDataLoader(
            dataset=val_subset,
            batch_size=orig_val.batch_size,
            shuffle=False,
            num_workers=orig_val.num_workers,
            drop_last=orig_val.drop_last,
            collate_fn=orig_val.collate_fn,
            pin_memory=orig_val.pin_memory,
        )

        # Wrap the loaders for distributed training if it is enabled.
        fold_train_loader = dist_utils.warp_loader(fold_train_loader, shuffle=True)
        fold_val_loader = dist_utils.warp_loader(fold_val_loader, shuffle=False)

        best_stat = {'epoch': -1}

        for epoch in range(args.epoches):
            fold_train_loader.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                fold_train_loader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                fold_train_loader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
            )

            if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                self.lr_scheduler.step()

            self.last_epoch += 1

            # Evaluate the EMA weights if available, otherwise the raw model.
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                fold_val_loader,
                self.evaluator,
                self.device,
            )

            # Track the best value seen so far for every reported metric.
            for k in test_stats:
                if k in best_stat:
                    best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = test_stats[k][0]
            print(f'Fold {fold + 1} Epoch {epoch} best_stat: {best_stat}')

        fold_results.append(best_stat)

        if self.output_dir:
            dist_utils.save_on_master(
                self.state_dict(),
                self.output_dir / f'fold_{fold + 1}_best.pth'
            )

    # Summarize and persist the per-fold results.
    print("\n=== K-Fold Results ===")
    for i, r in enumerate(fold_results):
        print(f"Fold {i + 1}: {r}")

    if self.output_dir and dist_utils.is_main_process():
        with (self.output_dir / "kfold_results.txt").open("w") as f:
            for i, r in enumerate(fold_results):
                f.write(f"Fold {i + 1}: {r}\n")

    return fold_results
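# ---------------------------------------------------------------------------
# NOTE: `_SubsetWithEpoch` is referenced above but not defined in this
# section. Below is a minimal sketch of what such a helper could look like,
# assuming the wrapped dataset may expose an optional `set_epoch` hook that
# the fold loaders forward each epoch. The repository's actual implementation
# may differ; treat names and details here as illustrative only.
# ---------------------------------------------------------------------------
# from torch.utils.data import Subset
#
# class _SubsetWithEpoch(Subset):
#     """Dataset subset that forwards `set_epoch` to the wrapped dataset."""
#
#     def set_epoch(self, epoch):
#         # Only forward if the underlying dataset supports epoch hooks
#         # (e.g. for multi-scale or augmentation scheduling).
#         if hasattr(self.dataset, 'set_epoch'):
#             self.dataset.set_epoch(epoch)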