본문 바로가기

개발

Python list와 일반 variable의 scope, gradient hook

학습 도중 계산된 gradient의 값을 hook을 이용해 따는 코드를 넣었다.

방법은 크게 어렵지 않다.

 

def hook_function(tensor_gradient):
    # 내가 하고 싶은 일
    return tensor_gradient

TORCH_TENSOR.register_hook(hook_function)

 

torch.Tensor는 backward gradient만 얻을 수 있다.

만약 Tensor 대신 Module이라면 다른 forward, backward hook을 모두 사용할 수 있다.

 

내가 얻고 싶은 것은  gradient의 누적 값이어서 일단 첫 번째 방법을 시도했다.

 

gradients = []
def hook_function(tensor_gradient):
    gradients.append(tensor_gradient)
    return tensor_gradient

TORCH_TENSOR.register_hook(hook_function)

 

이 방법은 잘 작동했으나, 학습이 길어질수록 gradients에 계속 tensor_gradient 값이 쌓이게 되어 메모리 낭비라고 생각했다.

 

최종적으로 원하는 건 모든 gradient의 합이어서 아래의 방법을 시도했다.

 

gradients = torch.zeros((shape1, shape2))
def hook_function(tensor_gradient):
    gradients += tensor_gradient
    return tensor_gradient

TORCH_TENSOR.register_hook(hook_function)

 

shape1, shape2tensor_gradient의 shape을 따라 설정했다.

이랬더니 에러가 났다.

 

궁금해서 구글링을 해보니

https://stackoverflow.com/questions/28819420/scope-of-lists-in-python-3

 

Scope of lists in Python 3

This may be a very stupid question, but I'm a little uncertain as to why lists behave differently to other variables in Python 3, with regard to scope. In the following code... foo = 1 bar = [1,2,...

stackoverflow.com

 

gradients += tensor_gradient 코드는

gradients = gradients + tensor_gradient인데,

두번째 gradients가 local 변수지만 undefined된 상태라 에러가 난다는 것 같다.

 

참고로 아래와 같은 코드는 2를 출력한다. 하지만 함수 바깥의 gradients 변수의 값은 변하지 않고 torch.Tensor 형태이다.

 

gradients = torch.zeros((shape1, shape2))
def hook_function(tensor_gradient):
    gradients = 2
    print(gradients)
    return tensor_gradient

TORCH_TENSOR.register_hook(hook_function)

 

그래서... gradients 변수를 list로 만들고 그 안에 Tensor를 넣었다.

 

gradients = [torch.zeros((shape1, shape2))]
def hook_function(tensor_gradient):
    gradients[0] += tensor_gradient
    return tensor_gradient

TORCH_TENSOR.register_hook(hook_function)

 

잘 동작하는 것 같다.