AI/Pytorch

nn.Module의 hook 과 apply

LTSGOD 2023. 3. 17. 11:41

hook

  PyTorch에서 hook은 nn.Module 클래스의 메소드 중 하나로, 모델의 특정 레이어에 훅(hook)을 등록하여 해당 레이어의 입력, 출력, 혹은 중간값을 추적할 수 있도록 합니다.

hook은 register_forward_hook 메소드와 register_backward_hook 메소드를 사용하여 등록할 수 있습니다. register_forward_hook 메소드는 해당 레이어의 순전파(forward) 연산이 실행되기 전에 호출되는 함수를 등록합니다. 반면, register_backward_hook 메소드는 해당 레이어의 역전파(backward) 연산이 실행되기 전에 호출되는 함수를 등록합니다.

 

hook 함수는 다음과 같은 인자를 가질 수 있습니다.

 

  • module: 훅이 등록된 nn.Module 객체
  • input: 해당 레이어의 입력
  • output: 해당 레이어의 출력 (순전파 훅의 경우)
  • grad_input: 해당 레이어의 입력의 기울기 (역전파 훅의 경우)
  • grad_output: 해당 레이어의 출력의 기울기 (역전파 훅의 경우)

  hook 함수에서는 이러한 인자를 이용하여 원하는 기능을 구현할 수 있습니다. 예를 들어, 중간 레이어의 출력값을 추적하여 디버깅에 사용하거나, 역전파 과정에서 특정 레이어의 입력의 기울기를 수정하는 등의 작업이 가능합니다.

def my_forward_hook(module, input, output):
    # 중간 레이어의 출력값을 출력
    print(f"{module}의 출력값: {output}")

model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# nn.ReLU() 레이어의 출력값을 추적하기 위한 훅 등록
relu_layer = model[1]
relu_layer.register_forward_hook(my_forward_hook)

 중간 중간에 개발자의 custom code를 등록할 수 있게 해주는 것 이라고 볼 수 있다.

 

 

Pytorch 에서의 hook 2가지

  • Tensor의 hook
    • backward hook 만 등록 가능
    • _backward_hooks 로 등록된 hook을 조회할 수 있다.
  • Module의 hook
    • backward, forward hook 다 등록 할 수 있다.
    • module.__dict__로 hook을 조회 할 수 있다.(hook 뿐만 아니라 parameter와 module의 변수들이 저장되어 있다.)

 

hook 등록하기

model.register_forward_hook(hook_function)

model.register_full_backward_hook(hook_function)

  register_full_backward_hook 함수는 모델의 모든 layer의 gradient에 대한 정보를 받아 처리할 수 있는 함수를 인자로 받습니다. 이 함수는 모델의 역전파 과정이 끝난 후에 호출되며, 모든 layer에 대한 gradient 값과 함께 호출됩니다. 이를 통해 각 layer에서의 gradient 값을 사용하여 모델의 학습 동작을 변경하거나, gradient를 수정할 수 있습니다.

hook 조회하기

tensor._backward_hooks

module.__dict__

 

 


모델 전체에 내가 만든 method를 적용시키고 싶을 때 -  apply

보통 nn.module의 내부 메서드들은 model 전체에 적용시켜준다 ex) .cpu를 model 뒤에 붙이면 알아서 모델 아래 존재하는 module들에 .cpu를 적용한다.

 

apply 가 module 들을 순회 할 때는 Postorder Traversal(후위 순회)를 사용한다. apply는 기존의 모델에 Parameter 추가, 초기화 , repr 수정, 동적 메서드 생성 등 다양한 일을 할 수 있기 때문에 쓰는 방법을 알아 두는 것이 좋을 것 같다.

 

후위 순회??

 

      1
     / \
    2   3
   / \   \
  4   5   6
  
  후위 순회 결과 4, 5, 2, 6, 3, 1

 

 

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 30)

    def weight_init(self, layer):
        if isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def initialize_weights(self):
        self.apply(self.weight_init)

model = MyModel()
model.initialize_weights()

  위의 코드에서 weight_init 메서드는 인자로 받은 레이어가 nn.Linear 클래스의 객체인 경우 해당 레이어의 가중치와 편향을 Xavier 초기화와 0으로 초기화합니다.

 

  initialize_weights 메서드에서는 apply 메서드를 사용하여 모델 객체의 모든 레이어에 대해 weight_init 메서드를 호출합니다. 이를 통해 모델 객체의 모든 nn.Linear 레이어에 대해 가중치 초기화를 수행할 수 있습니다.