본문 바로가기

개발

Pytorch backward_hook에서 얻을 수 있는 gradient의 의미

Pytorch에는 gradient를 얻을 수 있는 api를 제공한다.

nn.Module 타입인 변수에 register_backward_hook(hook_function)을 해주면 된다.

이 때 hook_functionmodule, input_gradient, output_gradient를 인자로 받는다.

그런데 이렇게 설명만 보면 input_gradient, output_gradient라는 말이 약간 햇갈린다.

input, output이 forward pass를 기준으로 되어 있기 때문이고, 수식으로도 설명이 안 되어 있다.

 

그래서 도대체 두 개의 gradient가 뭘 의미하는 지 모르겠어서 예제를 통해 살펴보려고 한다.

아래와 같은 코드를 실행하면 module, input_gradient, output_gradient 값을 확인할 수 있다.

import torch
from torch import nn


def hook(module, input_gradient, output_gradient):
    print(module)
    print(input_gradient)
    print(output_gradient)


def forward_hook(module, input, output):
    print("forward hook")
    print(module)
    print(input)
    print(output)
    print("==========================")


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.t = torch.tensor((3.0, 4.0))
        self.p = nn.Parameter(self.t)

    def forward(self, x):
        return self.p * x


class MyModule2(nn.Module):
    def __init__(self):
        super(MyModule2, self).__init__()
        self.l = nn.Linear(2, 1, bias=False)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(self.l(x))


module = MyModule()
module.register_backward_hook(hook)
module2 = MyModule2()
module2.register_forward_hook(forward_hook)
module2.register_backward_hook(hook)

x = torch.tensor(5.0, requires_grad=True)
y = module(x)
z = module2(y)
z.backward()
print("module.p.grad: ", module.p.grad)
print("module2.l.weight:", module2.l.weight)
print("module2.l.weight.grad: ", module2.l.weight.grad)

# 결과:
"""
MyModule2(
  (l): Linear(in_features=2, out_features=1, bias=False)
  (activation): Sigmoid()
)
(tensor([15., 20.], grad_fn=<MulBackward0>),)
tensor([0.2502], grad_fn=<SigmoidBackward0>)
==========================
MyModule2(
  (l): Linear(in_features=2, out_features=1, bias=False)
  (activation): Sigmoid()
)
(tensor([0.1876]),)
(tensor([1.]),)
MyModule()
(tensor([-0.1397,  0.0533]), tensor(-0.0412))
(tensor([-0.0279,  0.0107]),)
module.p.grad:  tensor([-0.1397,  0.0533])
module2.l.weight: Parameter containing:
tensor([[-0.1490,  0.0568]], requires_grad=True)
module2.l.weight.grad:  tensor([[2.8137, 3.7516]])

"""

Network의 계산 과정은 다음과 같다.

y = p * x = (3.0, 4.0) * 5.0 = (15.0, 20.0) <- module2 forward_hook의 input

 

z = sigmoid(l * y) = sigmoid((-0.1490, 0.0568) * (15.0, 20.0)) = sigmoid(-0.1490 * 15 + 0.0568 * 20)

= sigmoid(-2.235 + 1.136) = sigmoid(-1.099) = 0.2502 <- module2 forward_hook의 output

※ l * y를 아래에서 쓰기 위해 t라 하겠다.

 

이를 토대로 backward 계산 과정을 수행해보면,

module2 backward_hook의 output_gradient는 loss function이 없으므로 1

module2 backward_hook의 input_gradient:

dz / dt = sigmoid 미분 = t * (1 - t) = 0.2502 * (1 - 0.2502) = 0.1876

이 부분이 제일 이해가 안 가는 부분이다. 왜 dz / dy 또는 dz / dl이 아니라 dz / dt인지...

module2 파라미터의 gradient:

dz / dl = (dz / dt) * (dt / dl) = 0.1876 * y = 0.1876 * (15.0, 20.0) = (2.8137, 3.7516)

 

module의 output_gradient:

dz / dy = (dz / dt ) * (dt / dy) = 0.1876 * l = 0.1876 * (-0.1490, 0.0568) = (-0.0279, 0.0107)

module의 input_gradient:

(1) dz / dp = (dz / dt) * (dt / dy) * (dy / dp) = (-0.0279, 0.0107) * x = (-0.1397, 0.0533)

(2) dz / dx = (dz / dt) * (dt / dy) * (dy / dx) = (-0.0279, 0.0107) * p = (-0.0279, 0.0107) * (3.0, 4.0) = -0.0412

 

이렇게 backward_hook이 뱉어내는 값이 정확히 어떤 의미인지 실제로 계산해보았다.

그런데 보통 parameter의 gradient를 구한다면 굳이 backward_hook을 써야 하나 싶다.

register_full_backward_hook()

그런데 register_backward_hook()을 쓰면 deprecated됐다는 말과 함께 register_full_backward_hook()을 써야 한다고 나온다. 그래서 위의 의문점이 해소되나 확인해볼 겸 register_backward_hook()register_full_backward_hook()으로 바꾼 후 실행해봤다.

import torch
from torch import nn


def hook(module, input_gradient, output_gradient):
    print(module)
    print(input_gradient)
    print(output_gradient)


def forward_hook(module, input, output):
    print("forward hook")
    print(module)
    print(input)
    print(output)
    print("==========================")


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.t = torch.tensor((3.0, 4.0))
        self.p = nn.Parameter(self.t)

    def forward(self, x):
        return self.p * x


class MyModule2(nn.Module):
    def __init__(self):
        super(MyModule2, self).__init__()
        self.l = nn.Linear(2, 1, bias=False)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(self.l(x))


module = MyModule()
module.register_full_backward_hook(hook) #수정
module2 = MyModule2()
module2.register_forward_hook(forward_hook)
module2.register_full_backward_hook(hook) #수정

x = torch.tensor(5.0, requires_grad=True)
y = module(x)
z = module2(y)
z.backward()
print("module.p.grad: ", module.p.grad)
print("module2.l.weight:", module2.l.weight)
print("module2.l.weight.grad: ", module2.l.weight.grad)

# 결과:
"""
forward hook
MyModule2(
  (l): Linear(in_features=2, out_features=1, bias=False)
  (activation): Sigmoid()
)
(tensor([15., 20.], grad_fn=<BackwardHookFunctionBackward>),)
tensor([0.5173], grad_fn=<SigmoidBackward0>)
==========================
MyModule2(
  (l): Linear(in_features=2, out_features=1, bias=False)
  (activation): Sigmoid()
)
(tensor([-0.0674,  0.0514]),)
(tensor([1.]),)
MyModule()
(tensor(0.0035),)
(tensor([-0.0674,  0.0514]),)
module.p.grad:  tensor([-0.3372,  0.2572])
module2.l.weight: Parameter containing:
tensor([[-0.2701,  0.2060]], requires_grad=True)
module2.l.weight.grad:  tensor([[3.7455, 4.9940]])
"""

다시 실행하면 module2의 l 값이 다시 initialize되어 값이 변했다.

 

출력을 확인해보면 위에서 생각한 의문점이 해소됐다.

module2input_gradient가 moduleoutput_gradient와 같아졌다.

결국 dz / dy를 계산한 값을 보여준 것이다.

 

또한 moduleinput_gradient도 달라졌다.

위에서는 dz / dp, dz / dx 둘 다 출력했는데, 지금은 dz / dx만 보여준다.

y와 x가 각각 module2module의 input(forward()의 인자)라서 이렇게 보여주는 게 맞는 것 같다.

dz / dp 는 p.grad로 확인하면 된다.

 

 

Embedding layer의 output gradient는 parameter의 gradient와 같다.

이는 embedding layer가 parameter의 값을 그대로 내보내는 특성 때문이다.