1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
def train(training_data_loader, validate_data_loader, start_epoch=0):
    """Run the training/validation loop for the module-level ``model``.

    Each batch is expected to be a 4-tuple ``(gt, lms, ms_hp, pan_hp)``;
    the network predicts a high-pass residual that is added back to the
    low-resolution multispectral image ``lms`` to form the output.

    Args:
        training_data_loader: iterable of training batches.
        validate_data_loader: iterable of validation batches (same layout).
        start_epoch: 0-based epoch index to resume from (default 0).

    NOTE(review): relies on module-level globals defined elsewhere in this
    file — ``model``, ``optimizer``, ``criterion``, ``writer``, ``epochs``,
    ``ckpt``, ``save_checkpoint`` — presumably set up in the script
    preamble; verify before reuse.
    """
    print('Start training...')
    device = "mps"  # NOTE(review): hard-coded Apple-GPU device — confirm intended
    for epoch in range(start_epoch, epochs, 1):
        epoch += 1  # switch to 1-based epoch numbering for logging/checkpoints
        epoch_train_loss, epoch_val_loss = [], []

        # ---- Epoch Train ----
        model.train()
        for iteration, batch in enumerate(training_data_loader, 1):
            gt, lms, ms_hp, pan_hp = (batch[0].to(device), batch[1].to(device),
                                      batch[2].to(device), batch[3].to(device))
            optimizer.zero_grad()
            # Final output := low-res multispectral + predicted high-pass residual.
            hp_sr = model(ms_hp, pan_hp)
            sr = lms + hp_sr
            loss = criterion(sr, gt)
            epoch_train_loss.append(loss.item())  # collect per-batch losses
            loss.backward()
            optimizer.step()
            # BUG FIX: the original used ``epoch * iteration`` as the
            # TensorBoard step, which is non-monotonic and collides across
            # epochs (e.g. 2*3 == 3*2). Use a strictly increasing global step.
            global_step = (epoch - 1) * len(training_data_loader) + iteration
            for name, layer in model.named_parameters():
                writer.add_histogram('net/' + name + '_data_weight_decay',
                                     layer, global_step)
        # lr_scheduler.step()  # if update_lr, activate here!

        # Mean of all batch losses = the epoch's training loss.
        t_loss = np.nanmean(np.array(epoch_train_loss))
        writer.add_scalar('mse_loss/t_loss', t_loss, epoch)
        # BUG FIX: format arguments were swapped (printed "total/current");
        # print the conventional "current/total" order instead.
        print('Epoch: {}/{} training loss: {:.7f}'.format(epoch, epochs, t_loss))
        if epoch % ckpt == 0:  # save a checkpoint every ``ckpt`` epochs
            save_checkpoint(model, epoch)

        # ---- Epoch Validate ----
        model.eval()
        with torch.no_grad():  # no gradients needed for validation
            for iteration, batch in enumerate(validate_data_loader, 1):
                gt, lms, ms_hp, pan_hp = (batch[0].to(device), batch[1].to(device),
                                          batch[2].to(device), batch[3].to(device))
                hp_sr = model(ms_hp, pan_hp)
                sr = lms + hp_sr
                loss = criterion(sr, gt)
                epoch_val_loss.append(loss.item())
            if epoch % 10 == 0:  # only report validation loss every 10 epochs
                v_loss = np.nanmean(np.array(epoch_val_loss))
                writer.add_scalar('val/v_loss', v_loss, epoch)
                print(' validate loss: {:.7f}'.format(v_loss))

    writer.close()  # flush and close the TensorBoard writer after all epochs
|