꺼내먹는지식 준

FC layer 를 1X1 Convolution 으로 바꾸기 본문

AI/CV

FC layer 를 1X1 Convolution 으로 바꾸기

알 수 없는 사용자 2022. 3. 12. 16:41
model = nn.Linear(10, 20)
model.weight.shape

# torch.Size([20, 10])

예상과 달리 nn.Linear(in,out) 반면, weight는 [20,10] 이다. 

 

이는 $y = xA^T + b$ 로 인해서 A(weight) 에 transpose 된 값이 들어가야 해서 그렇다. 

 

여기서 하고자 하는 task 는 다음과 같다. 

1) VGG backbone 이후를 1 X 1 conv 로 대체 

2) 1 X 1 conv 의 weight를 FC layer 의 weight로 대체 

super(VGG11Segmentation, self).__init__()

    self.backbone = VGG11BackBone()


    with torch.no_grad():
      self.conv_out = nn.Conv2d(512, num_classes, kernel = 1, stride=1)
      
  
    self.upsample = torch.nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)​

사실상 간단하게 backbone 을 가져온 후, fc layer 의 weight를 사용할 것임으로 torch.no_grad 를 선언한다. 

def copy_last_layer(self, fc_out):

    
    reshaped_fc_out = fc_out.weight.detach()
    reshaped_fc_out = torch.reshape(reshaped_fc_out, (7 , 512 , 1, 1))
    self.conv_out.weight = torch.nn.Parameter(reshaped_fc_out)
    
    assert self.conv_out.weight[0][0] == fc_out.weight[0][0]
    
    
    return self.conv_out

fc_out weight 의 shape:  7, 512 

 

detach() : 기존 Tensor에서 gradient 전파가 안되는 텐서 생성

단 storage를 공유하기에 detach로 생성한 Tensor가 변경되면 원본 Tensor도 똑같이 변한다. 

 

여기서 유의할 점은 1 X 1 conv의 weight 의 shape인데, 잘 보면 7, 512, 1, 1 이 유지된다. 

공식 문서를 살펴보면, 

라 한다. 

즉 out_channel, in_channel, kernel height, kernel width 이다.

 

그냥 deep copy를 안하고 torch.nn.Parameter로 감싸는 이유는 무엇일까?

출처: https://velog.io/@yh8109/PyTorch-2.-Model-%EB%A7%8C%EB%93%A4%EA%B8%B0-%EC%83%81

 

# Init segmentation network
modelS  = VGG11Segmentation()

# Copy the backbone of classification to segmentation backbone
modelS.backbone = modelC.backbone

# Copy the weights of the fc layer to 1x1 conv layer of segmentation network
fc_out = modelC.fc_out
modelS.copy_last_layer(fc_out)

그 후 다음과 같이 copy 해서 완료!

 

 

'AI > CV' 카테고리의 다른 글

Instance segmentation, Panoptic segmentation  (0) 2022.03.14
Object Detection 개괄 from selective search to SSD  (0) 2022.03.13
AutoGrad  (0) 2022.03.11
CNN Visualization  (0) 2022.03.11
Semantic segmentation  (0) 2022.03.10
Comments