Skip to content

Commit 6418958

Browse files
committed
fix sampler bugs and update dataloader
1 parent fd0007e commit 6418958

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,12 @@
458458
gen.dataset.epoch_now = epoch
459459
gen_val.dataset.epoch_now = epoch
460460

461+
if distributed:
462+
train_sampler.set_epoch(epoch)
463+
461464
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
462465

463466
fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
467+
468+
if local_rank == 0:
469+
loss_history.writer.close()

utils/dataloader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import cv2
44
import numpy as np
5+
import torch
56
from PIL import Image
67
from torch.utils.data.dataset import Dataset
78

@@ -354,5 +355,6 @@ def yolo_dataset_collate(batch):
354355
for img, box in batch:
355356
images.append(img)
356357
bboxes.append(box)
357-
images = np.array(images)
358-
return images, bboxes
358+
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
359+
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
360+
return images, bboxes

utils/utils_fit.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
2121
images, targets = batch[0], batch[1]
2222
with torch.no_grad():
2323
if cuda:
24-
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
25-
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
26-
else:
27-
images = torch.from_numpy(images).type(torch.FloatTensor)
28-
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
24+
images = images.cuda()
25+
targets = [ann.cuda() for ann in targets]
2926
#----------------------#
3027
# 清零梯度
3128
#----------------------#
@@ -94,11 +91,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
9491
images, targets = batch[0], batch[1]
9592
with torch.no_grad():
9693
if cuda:
97-
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
98-
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
99-
else:
100-
images = torch.from_numpy(images).type(torch.FloatTensor)
101-
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
94+
images = images.cuda()
95+
targets = [ann.cuda() for ann in targets]
10296
#----------------------#
10397
# 清零梯度
10498
#----------------------#

0 commit comments

Comments
 (0)