Hook是什么?

Hook(钩子)其实并不是Pytorch特有的机制,其在软件工程中也是相当常见的,一般来说Hook表示一种自动触发的机制,即在遇到某些时间/情况之后会自动执行的事项,其实在生活中也会遇到很多Hook的事件:

  • 移动到光线变化的环境里,手机屏幕亮度会跟着变化
  • 水烧开后就会沸腾把壶盖顶开
  • 火灾情况下温度升高自动触发报警系统和灭火喷头

总而言之,虽然上面很多情况即便没有Hook,我们也能实现(比如手动调亮度、手动打开报警和灭火器等),但是Hook作为一种强大的自动触发机制,能够很大程度上帮助我们提高效率。

Pytorch中的 Hook是干嘛的?

当想要查看网络输出中每层特征的shape时,有没有过手动print每个tensor.shape的情况?虽然快但是不“优雅”而且很有可能导致代码显得冗余杂乱。这时候如果在网络前向过程中设置hook机制,就能自动打印张量的shape,并且不会影响原代码的功能和逻辑,这些被添加的小功能就像一个小钩子🪝一样“挂”在原代码逻辑上但是不会改变原逻辑。

在Pytorch中Hook能做的事情非常多:

  • 打印输出每层张量的shape
  • 查看或修改每层参数的梯度(比如进行梯度裁剪)
  • 可视化网络中间层的特征图
  • ….

在Pytorch中常用的可以给张量(Tensor)或者模型(Module)设置Hook:

针对Tensor和Module的hook函数签名如下:

from torch import nn, Tensor
def module_hook(module:nn.Module, input:Tensor, output:Tensor):
    # 接受module类型对象,及其输入输出,这里可以做尺寸的打印、梯度裁剪、特征提取等
    
def tensor_hook(grad:Tensor):
    # 接受tensor的梯度信息,这里也可以做尺寸的打印、梯度裁剪

Pytorch中Hook的应用

1.打印中间张量的信息

import torch
from torch import nn, Tensor

class DemoMoudule(nn.Module):
    def __init__(self):
        super().__init__()
        # 构建一个简单的DNN网络: 由两个卷积输入层、一个BN层、一个卷积输出层
        self.conv_in = nn.Sequential(nn.Conv2d(3, 2, 3, 2),
                                      nn.Conv2d(2, 1, 3, 2))
        self.bn = nn.BatchNorm2d(1)
        self.conv_out = nn.Conv2d(1, 20, 3, 1)

        # 为该网络的每个一级子module(即conv_in、bn和conv_out)注册前向hook,
        # 在forward时候会自动调用对应的函数(这里是打印该module输出层的名称、尺寸、均值)
        for name, layer in self.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda l, _, output:
                print("{}:{},{}".format(l.__name__, output.shape, torch.mean(output)))
            )
            
    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_in(x)
        x = self.bn(x)
        x = self.conv_out(x)
        return x

demo_model = DemoMoudule()
dummy_input = torch.ones(10, 3, 28, 28)
dummy_output = demo_model(dummy_input)
'''
输出:
conv_in:torch.Size([10, 1, 6, 6]),-0.3441038727760315
bn:torch.Size([10, 1, 6, 6]),0.0
conv_out:torch.Size([10, 20, 4, 4]),0.07038399577140808
'''

利用此技术,我们还可以针对已有的网络(比如ResNet50),在不修改该网络定义和源码的同时使用类似上面的一个封装,在前向过程中打印对应的张量尺寸(可参考zhihu.McGL)。

2.提取特征

import torch
from torch import nn, Tensor
from typing import Iterable, Dict

class FeatureExtreactor(nn.Module):
    def __init__(self, model: nn.Module, layer_names:Iterable[str]):
        super().__init__()
        self.model = model
        self.layer_names = layer_names
        self.__extraced_features = {}

        org_modules = dict([*self.model.named_modules()])
        for layer_name in layer_names:
            layer = org_modules[layer_name]
            layer.__name__ = layer_name
            # 将指定名称的Module输出添加到待返回的集合中
            layer.register_forward_hook(self.append_features)

    def append_features(self, layer, _, output_tensor):
        self.__extraced_features[layer.__name__] = output_tensor

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self.__extraced_features

# 获取指定层的属性
feature_extreactor = FeatureExtreactor(demo_model, layer_names=["conv_in", "conv_out"])
dummy_input = torch.ones(10, 3, 28, 28)
demo_features = feature_extreactor(dummy_input)
for name, feature in demo_features.items():
    print("{}: {} {}".format(name, feature.shape, torch.mean(feature)))
    
'''
在使用之前DemoMoudule()类的情况下,输出:
conv_in: torch.Size([10, 1, 6, 6]) 0.05633455887436867
conv_out: torch.Size([10, 20, 4, 4]) 0.04933914169669151
'''

上面利用Module的前向hook可以根据指定的模块名获取网络指定的特征层,有了获取的特征层,我们能做的事情就非常多了,不仅仅获取shape、均值,此外也能进行特征图可视化等等。

而在深度学习中常常会要求计算图像的VGG特征也可以使用该方法推理获得。

3.梯度裁剪1

在未进行梯度裁剪的时候,我们打印demo_model网络最后一个卷积的前10个biase的梯度如下:

demo_model = DemoMoudule()
dummy_input = torch.ones(10, 3, 28, 28)
pred = demo_model(dummy_input)
loss = pred.log().mean()
loss.backward()
print(demo_model.conv_out.bias.grad)
'''
输出:
tensor([  0.7894,   1.0285,  -0.6203,  -0.1882,   0.6290,  -0.2002,   0.5751,
          0.1875,  -0.5093,   0.3338,  -1.4956,   0.2797,  -0.4018,  -0.1860,
        -12.1006,   0.2474,   1.7059,   0.1834,   0.3505,   0.3189])
'''

在使用Tensor形式的hook机制时,我们设定参数的梯度tensor的梯度在某个范围内,如下:

def gradient_clipper(model: nn.Module, clip_val:float)-> nn.Module:
    for parameter in model.parameters():
        # 对梯度tensor添加用于梯度截断的hook
        parameter.register_hook(lambda grad: grad.clamp_(-clip_val, clip_val))
    return model

clipped_model = gradient_clipper(demo_model, 0.01)
pred = clipped_model(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_model.conv_out.bias.grad)

'''
输出:
tensor([-0.0100, -0.0100,  0.0100, -0.0100, -0.0100, -0.0100,  0.0100, -0.0100,
        -0.0100, -0.0100,  0.0100, -0.0100, -0.0100,  0.0100,  0.0100,  0.0100,
        -0.0100,  0.0100, -0.0100, -0.0100])
可以看到梯度被限制到-0.01~0.01之间。
'''

3.梯度裁剪2

使用Moduel的register_backward_hook函数也能进行梯度裁剪:

def gradient_clipper2(model: nn.Module, clip_val:float)-> nn.Module:
    # grad_input元组包含(bias的梯度,输入x的梯度,权重weight的梯度),grad_output元组包含输出y的梯度。
    # 返回的是修改后的grad_input
    def back_hook(module, grad_input, grad_output):
        print('grad_input: ', grad_input)
        print('grad_output: ', grad_output)
        return grad_input[0].clamp(-clip_val, clip_val), \
               grad_input[1].clamp(-clip_val*2, clip_val*2), \
               grad_input[2].clamp(-clip_val*3, clip_val*3),
    for moduel in model.modules():
        moduel.register_backward_hook(back_hook)
    return model

# 因为grad_input是对输入x的梯度,所以要求x也是有梯度的,即要设定requires_grad=True
x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = gradient_clipper2(nn.Linear(3, 2), 0.001)
y = module(x)
y.mean().backward()
print('module_bias: {}, x:{} module_weight:{}'.
      format(module.bias.grad, x.grad, module.weight.grad))

'''
输出:
grad_input:  (tensor([0.5000, 0.5000]), tensor([[0.2492, 0.2174, 0.0614]]),  tensor([[0.5000, 0.5000],
        [1.0000, 1.0000],
        [5.0000, 5.0000]]))
grad_output:  (tensor([[0.5000, 0.5000]]),)
module_bias: tensor([0.0010, 0.0010]), x:tensor([[0.0020, 0.0020, 0.0020]]) module_weight:tensor([[0.0030, 0.0030, 0.0030], [0.0030, 0.0030, 0.0030]])
'''

总结

hook机制能够方便快捷地帮助我们做一些调试等辅助工作,同时也能保证代码的简洁性,其实除了上面的三种hook,pytorch还有register_full_backward_hook、register_forward_pre_hook等,但是比较常用的三种和对应的用法列在上面了,其他的用到时候再自己看后补充进来!

参考: