3. Python Code

Important Modules for Adaptive Convolution

Weights Initialization

  • Initialize Conv2d, BatchNorm2d, and Linear layers
import math

import torch
import torch.nn as nn

# ---- Initialization ----

def init_weights(*modules):
    for module in modules:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):   ## initialization for Conv2d

                variance_scaling_initializer(m.weight)  # method 1: initialization
                #nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')  # method 2: initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d):   ## initialization for BN
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):     ## initialization for nn.Linear
                # variance_scaling_initializer(m.weight)
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

def variance_scaling_initializer(tensor):
    from scipy.stats import truncnorm

    def truncated_normal_(tensor, mean=0, std=1):
        with torch.no_grad():
            size = tensor.shape
            tmp = tensor.new_empty(size + (4,)).normal_()
            valid = (tmp < 2) & (tmp > -2)
            ind = valid.max(-1, keepdim=True)[1]
            tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
            tensor.data.mul_(std).add_(mean)
            return tensor

    def variance_scaling(x, scale=1.0, mode="fan_in", distribution="truncated_normal", seed=None):
        fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(x)
        if mode == "fan_in":
            scale /= max(1., fan_in)
        elif mode == "fan_out":
            scale /= max(1., fan_out)
        else:
            scale /= max(1., (fan_in + fan_out) / 2.)
        if distribution == "normal" or distribution == "truncated_normal":
            # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
            stddev = math.sqrt(scale) / .87962566103423978
        # print(fan_in,fan_out,scale,stddev)#100,100,0.01,0.1136
        truncated_normal_(x, 0.0, stddev)
        return x/10*1.28

    variance_scaling(tensor)

    return tensor
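
A brief usage note: init_weights iterates over every argument with .modules(), so you can pass either the whole network or individual sub-modules. A minimal sketch (the sub-module names are placeholders, not the real layer names of the model in this post):

# ---- Usage sketch for init_weights (hypothetical sub-module names) ----
net = PanNet()                            # PanNet is the model loaded later in this post
init_weights(net)                         # initialize every Conv2d / BatchNorm2d / Linear in the network
# init_weights(net.conv1, net.res_block)  # or initialize only selected sub-modules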

Summary

  • Modify input_size, batch_size, and x to match your own model
# ---- Summary ----
def summaries(model, writer=None, grad=False):
    model = model.to("mps")
    if grad:
        from torchsummary import summary
        summary(model, input_size=[(8, 16, 16), (1, 64, 64)], batch_size=1)
    else:
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name)

    if writer is not None:
        x = torch.randn(1, 64, 64, 64)
        writer.add_graph(model,(x,))
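
Note that the PanNet-style model used later in this post takes two inputs (ms_hp and pan_hp), while add_graph above is fed a single tensor. A hedged sketch of exporting the graph for a two-input model (the shapes mirror the torchsummary call above and are assumptions):

# ---- Sketch: add_graph with two example inputs (shapes are assumptions) ----
ms_hp_dummy = torch.randn(1, 8, 16, 16).to("mps")
pan_hp_dummy = torch.randn(1, 1, 64, 64).to("mps")
writer.add_graph(model, (ms_hp_dummy, pan_hp_dummy))  # add_graph accepts a tuple of example inputs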

Predefine Jobs for Training

# ------ Predefine ------

import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# ---- 1. Device (MPS) Settings ----

SEED = 1
torch.manual_seed(SEED)
torch.mps.manual_seed(SEED)
"mps" = torch.device("mps")

# torch.cuda.manual_seed_all(SEED)
# cudnn.benchmark = True
# cudnn.deterministic = True
# cudnn.benchmark = False

# ---- 2. HYPER PARAMS ----

lr = 0.001
epochs = 450
ckpt = 50
batch_size = 32
model_path = "Weights/.pth"

# ---- 3. Load Model + Loss + Optimizer + Learn_rate_update ----

# Load Model
model = PanNet()  # the network to train; keep the name consistent with model_path when saving
if os.path.isfile(model_path):
    model.load_state_dict(torch.load(model_path, map_location=device))   ## load the pretrained weights if a checkpoint exists
    print('PanNet is successfully loaded from %s' % model_path)

# Summary & Loss
summaries(model, grad=True)    ## summarize the network
criterion = nn.MSELoss(reduction='mean')  ## Define the Loss function L2Loss

# Optimize & Update LR
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)   # Optimizer 1: Adam
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.1)

# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-7)  ## optimizer 2: SGD
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=180, gamma=0.1)  # learning-rate update: lr = lr* 1/gamma for each step_size = 180

# ---- 4. Tensorboard_show + Save_model ----
if os.path.exists('train_logs'):  # remove old TensorBoard logs  ## Tensorboard_show: case 1
    shutil.rmtree('train_logs')   # ---> console (view TensorBoard): tensorboard --logdir=train_logs

writer = SummaryWriter('./train_logs/50')    ## Tensorboard_show: case 2

Training Function

def train(training_data_loader, validate_data_loader, start_epoch=0):
    print('Start training...')

    for epoch in range(start_epoch, epochs, 1):

        epoch += 1
        epoch_train_loss, epoch_val_loss = [], []

        # ---- Epoch Train ----
        model.train()

        for iteration, batch in enumerate(training_data_loader, 1):
            gt, lms, ms_hp, pan_hp = batch[0].to("mps"), batch[1].to("mps"), batch[2].to("mps"), batch[3].to("mps")

            optimizer.zero_grad()  # fixed

            # ---- Final computation step of the model ----

            hp_sr = model(ms_hp, pan_hp)  # call model
            sr = lms + hp_sr  # output:= lms + hp_sr
            
            loss = criterion(sr, gt)  # compute loss
            epoch_train_loss.append(loss.item())  # save all losses into a vector for one epoch

            # --- Fixed Part ----

            loss.backward()   # fixed
            optimizer.step()  # fixed

            for name, layer in model.named_parameters():
                # writer.add_histogram('torch/'+name + '_grad_weight_decay', layer.grad, epoch*iteration)
                writer.add_histogram('net/'+name + '_data_weight_decay', layer, epoch*iteration)

        #lr_scheduler.step()  # if update_lr, activate here!

        t_loss = np.nanmean(np.array(epoch_train_loss))  # compute the mean value of all losses, as one epoch loss
        writer.add_scalar('mse_loss/t_loss', t_loss, epoch)  # write to tensorboard to check
        print('Epoch: {}/{} training loss: {:.7f}'.format(epoch, epochs, t_loss))  # print loss for each epoch

        if epoch % ckpt == 0:  # if each ckpt epochs, then start to save model
            save_checkpoint(model, epoch)

        # ---- Epoch Validate ----
        model.eval()  # fixed
        with torch.no_grad():  # fixed
            for iteration, batch in enumerate(validate_data_loader, 1):
                gt, lms, ms_hp, pan_hp = batch[0].to("mps"), batch[1].to("mps"), batch[2].to("mps"), batch[3].to("mps")

                hp_sr = model(ms_hp, pan_hp)
                sr = lms + hp_sr

                loss = criterion(sr, gt)
                epoch_val_loss.append(loss.item())

        if epoch % 10 == 0:
            v_loss = np.nanmean(np.array(epoch_val_loss))
            writer.add_scalar('val/v_loss', v_loss, epoch)
            print('             validate loss: {:.7f}'.format(v_loss))

    writer.close()  # close tensorboard
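
save_checkpoint is called inside train() but not shown in this section. A minimal sketch, assuming checkpoints go into the same Weights/ directory as model_path (the file-name pattern is an assumption, not the author's original helper):

# ---- Sketch of the missing save_checkpoint helper (naming pattern assumed) ----
def save_checkpoint(model, epoch):
    if not os.path.exists('Weights'):
        os.makedirs('Weights')
    model_out_path = 'Weights/epoch_{}.pth'.format(epoch)   # assumed naming pattern
    torch.save(model.state_dict(), model_out_path)          # state_dict only, matching load_state_dict above
    print('Checkpoint saved to {}'.format(model_out_path))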

Main Function of Training

# ------ Main Function ------

if __name__ == "__main__":
    train_set = Dataset_Pro('./training_data/train_small.h5')  # create data for training   # 100
    training_data_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=batch_size, shuffle=True,
                                      pin_memory=True, drop_last=True)

    validate_set = Dataset_Pro('./training_data/valid_small.h5')
    validate_data_loader = DataLoader(dataset=validate_set, num_workers=0, batch_size=batch_size, shuffle=True,
                                      pin_memory=True, drop_last=True)

    train(training_data_loader, validate_data_loader)
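
Dataset_Pro is likewise used but not defined here. A minimal sketch, assuming each .h5 file stores four N x C x H x W float arrays under the keys gt, lms, ms_hp and pan_hp (the key names are assumptions), returned in the order that train() unpacks them:

# ---- Hypothetical sketch of Dataset_Pro (HDF5 key names are assumptions) ----
import h5py
from torch.utils.data import Dataset

class Dataset_Pro(Dataset):
    def __init__(self, file_path):
        super(Dataset_Pro, self).__init__()
        data = h5py.File(file_path, 'r')
        self.gt = torch.from_numpy(data['gt'][...]).float()
        self.lms = torch.from_numpy(data['lms'][...]).float()
        self.ms_hp = torch.from_numpy(data['ms_hp'][...]).float()
        self.pan_hp = torch.from_numpy(data['pan_hp'][...]).float()

    def __getitem__(self, index):
        # order matches the unpacking in train(): gt, lms, ms_hp, pan_hp
        return self.gt[index], self.lms[index], self.ms_hp[index], self.pan_hp[index]

    def __len__(self):
        return self.gt.shape[0]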

Known Modules

ResNet

# ---- ResBlock Definition ----
class Resblock(nn.Module):
    def __init__(self):
        super(Resblock, self).__init__()
        channel = 32
        self.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1, bias=True, device="mps")
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1, bias=True, device="mps")
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):  # x: input feature map, e.g. the high-pass component of ms
        x = x.to("mps")
        rs1 = self.relu(self.conv1(x))
        rs1 = self.conv2(rs1)
        rs = torch.add(x, rs1)
        return rs
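
In PanNet-style networks this block is typically stacked several times; a hedged usage sketch (a depth of four blocks is an assumption, and an MPS device must be available since the block moves its input to "mps"):

# ---- Usage sketch: stack several Resblocks (depth of 4 is an assumption) ----
backbone = nn.Sequential(Resblock(), Resblock(), Resblock(), Resblock())
x = torch.randn(1, 32, 64, 64)   # channel must be 32 to match the block definition
out = backbone(x)                # residual blocks preserve the shape: (1, 32, 64, 64)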

Upsampling

  • Upsampling is implemented with transposed convolution (ConvTranspose2d); the kernel/stride/padding choices for each scale factor are verified in the sketch after the code block
# ---- Upsampling definition ----
class Upsampling(nn.Module):
    def __init__(self, upscale):
        super(Upsampling, self).__init__()
        self.upscale = upscale
        channel = 32

        # upsampling (transposed convolutions for 2x / 3x / 4x)
        self.upsp2 = nn.ConvTranspose2d(in_channels=channel, out_channels=channel, kernel_size=2, stride=2, padding=0, bias=True, device="mps")
        self.upsp3 = nn.ConvTranspose2d(in_channels=channel, out_channels=channel, kernel_size=5, stride=3, padding=1, bias=True, device="mps")
        self.upsp4 = nn.ConvTranspose2d(in_channels=channel, out_channels=channel, kernel_size=8, stride=4, padding=2, bias=True, device="mps")

        # conv & relu
        self.conv1 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1, bias=True, device="mps")
        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1, bias=True, device="mps")
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = x.to("mps")
        if self.upscale == 2:
            rs = self.upsp2(x)
        elif self.upscale == 3:
            rs = self.upsp3(x)
        elif self.upscale == 4:
            rs = self.upsp4(x)
        else:
            raise ValueError("Wrong upscaling factor: expected 2, 3, or 4, got {}".format(self.upscale))

        rs = self.relu(self.conv1(rs))
        rs = self.conv2(rs)

        return rs
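
The kernel/stride/padding triples above are chosen so that each branch scales the spatial size exactly: for ConvTranspose2d, H_out = (H_in - 1) * stride - 2 * padding + kernel_size, so (kernel=2, stride=2, padding=0) gives 2*H, (5, 3, 1) gives 3*H, and (8, 4, 2) gives 4*H. A quick check (run on CPU purely for illustration; the module above uses device="mps"):

# ---- Check of the upsampling output sizes (CPU, illustration only) ----
up2 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, padding=0)
up3 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=3, padding=1)
up4 = nn.ConvTranspose2d(32, 32, kernel_size=8, stride=4, padding=2)
x = torch.randn(1, 32, 16, 16)
print(up2(x).shape, up3(x).shape, up4(x).shape)
# torch.Size([1, 32, 32, 32]) torch.Size([1, 32, 48, 48]) torch.Size([1, 32, 64, 64])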