前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >以动制动 | Transformer 如何处理动态输入尺寸

以动制动 | Transformer 如何处理动态输入尺寸

作者头像
OpenMMLab 官方账号
发布2022-04-09 16:47:02
2K0
发布2022-04-09 16:47:02
举报
文章被收录于专栏:OpenMMLabOpenMMLab

从一个参数说起

在图像分类任务中,主干网络是视觉神经网络中进行图像特征提取的主体,常见的算法包括我们耳熟能详的 ResNet、Vision Transformer 等。

不知道大家是否注意到,用于图像分类的主干网络中,基于 CNN 结构的网络,通常不需要我们指定输入图像的尺寸,同时,同一个主干网络就能够处理各种尺寸的图像输入。但基于 Transformer 结构的主干网络,就往往需要我们在搭建网络时指定输入的图像尺寸参数 —— img_size,而且网络的前向推理输入也必须是符合这一尺寸的图像。

那么,为什么 Transformer 结构的网络中需要指定输入的图像尺寸呢?我们能否移除这个限制,让网络动态地支持各种尺寸的输入图像呢?这对于一些下游任务有重要的作用,也已经有了一些成熟的解决方案。在最新版的 MMClassification 中,我们将这一功能扩展到了各种基于 Transformer 结构的主干网络中,实现了分类任务与下游任务主干网络的统一。

接下来,就让我们了解一下,Transformer 结构网络支持动态输入尺寸的阻碍与解决方法。

“罪魁祸首”——位置编码

说起 Transformer 结构,大家最先想到的关键结构大概率是注意力模块,但这里,问题并不出在注意力模块中,因为注意力模块天然地支持动态尺寸输入。让我们看下这张经典的 ViT 结构图:

首先,我们会将输入图片按照一个固定的 patch size 切分成若干个 patch。之后每个图像 patch 经过一个线性映射得到对应的一个特征向量。这一个个特征向量如果按照其对应 patch 在图像上的位置排列,就是一张图像经过编码后的特征图,其长和宽分别等于原图在纵向和横向切分成了多少个 patch。之后,我们需要给这张特征图加上位置编码(position embedding),以体现每个 patch 在图像上的位置。

当输入图片尺寸发生变化时,由于每个 patch 的尺寸固定,图片切分出的 patch 数就会发生变化。表现在上述特征图中,就是特征图的尺寸发生了变化。这样一来,我们原本位置编码图的尺寸就和图像特征图的尺寸对不上了,无法进行后续的计算。

找到了问题所在,解决的方法也就顺理成章了。位置编码代表的是 patch 所在位置的附加信息,那么如果和图像特征图的尺寸不匹配,只需要使用双三次插值法(Bicubic)对位置编码图进行插值缩放,缩放到与图像特征图一致的尺寸,就同样可以表现每个 patch 在图片中的位置信息。

代码语言:javascript
复制
import torch
import torch.nn.functional as F

# 原始位置编码
pos_embed = torch.rand(1, 197, 64)
# 原始图像尺寸下,长和宽方向的 patch 数
src_shape = (14, 14)
# 输入图像尺寸下,长和宽方向的 patch 数
dst_shape = (16, 16)
# 额外编码数,在 ViT 中,为 1,指 class embedding;在 DeiT 中为 2
num_extra_tokens = 1

_, L, C = pos_embed.shape
src_h, src_w = src_shape
# 位置编码第二个维度大小应当等于 patch 数 + 额外编码数
assert L == src_h * src_w + num_extra_tokens

# 拆分额外编码和纯位置编码
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]

# 将位置编码组织成 (1, C, H, W) 形式,其中 C 为通道数
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
# 进行双三次插值
dst_weight = F.interpolate(src_weight, size=dst_shape, mode='bicubic')
# 重组位置编码为(1,H*W, C)形式,再拼接上额外编码,即获得新的位置编码
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
pos_embed = torch.cat((extra_tokens, dst_weight), dim=1)

在官方实现中,这一方法已经有了一些应用。不过源码只把这一方法用于微调模型时,加载和微调模型输入尺寸不同的预训练模型权重。而在 MMClassificaiton 中,我们将这一方法应用于每次模型的前向推理,使每次推理都可以应对不同尺寸的图像输入。

需要提醒的是,就像缩放照片会损失信息,这种对位置编码的插值也不是无损的,建议输入图像的尺度变化不要过大,同时需要在动态尺度输入下进行新的微调训练。

下面一个例子,展示了在 MMClassification 中使用 ViT 模型处理不同尺寸输入的流程:

代码语言:javascript
复制
import torch
from mmcls.models import build_backbone

cfg = dict(type='VisionTransformer', arch='base')
vit_model = build_backbone(cfg)

inputs = torch.rand(1, 3, 224, 224)
patch_embed, cls_token = vit_model(inputs)[-1]  # 获取模型最后一层输出
assert patch_embed.shape == (1, 768, 14, 14)

inputs = torch.rand(1, 3, 256, 384)
patch_embed, cls_token = vit_model(inputs)[-1]
assert patch_embed.shape == (1, 768, 16, 24)  # 输入尺寸不同,输出特征图的尺寸也不同

特殊的 Swin-Transformer

Position embedding 的问题遍布于经典 ViT 结构的主干网络中,但并不存在于 Swin-Transformer 中。对 Swin-Transformer 有了解的读者应该知道,在 Swin-Transformer 中,没有使用绝对位置编码,也即上文所说的那种与输入图像 patch 一一对应的位置编码;而是配合窗口注意力机制,使用了一种局限于窗口内部的相对位置编码机制。当我们改变输入图像的大小,可能会改变窗口的数量,但并不会影响窗口内部的相对位置编码。

那么 Swin-Transformer 是否天然地具备处理动态输入尺寸的能力呢?其实不尽然,在官方提供的分类 Swin-Transformer 实现中,我们依然需要指定输入图像的尺寸。这涉及到 Swin-Transformer 中的 shfit-window 注意力计算机制。

如上图所示,每个灰格代表一个图像 patch 对应的特征向量,而蓝色的格子则代表一个分窗,整张图就是图像的特征图。因为窗口偏移(shift)的原因,原本 4x4 的窗口大小,在边缘区域变成了一些更小的窗口。在 Swin-Transformer 中,为了高效计算这种情况下的窗口注意力,首先使用 torch.roll 函数,将原本的图像特征图循环偏移成右图所示的排布。之后,我们将这些原本小于 4x4 的边缘窗口组合,如 H 和 B 组合, I、G、C、A 组合,将所有窗口都拼凑成立了 4x4 的窗口。

但是如图的 H 和 B 虽然为了高效计算而临时组队成了一个窗口,但 H 窗口的特征向量不应该能注意到 B 窗口的特征向量。因此需要一个 mask,在计算属于 H 窗口的特征向量的注意力时,这个 mask 能够屏蔽属于 B 窗口的特征向量,使得 H 窗口只注意 H 窗口, B 窗口只注意 B 窗口。

为了便于理解 mask 的生成方式,我们以一个更小的特征图(4x4)及更小的窗口大小(2x2)为例,如下图所示,对特征图进行分窗,生成了 9 个窗口,对特征图进行偏移,并组合部分分窗后,生成了 4 个用于计算的分窗。这里每个窗口都对应了一种窗口组合情况,因此需要使用不同的 mask 来计算注意力。这里,我们以 attention_masks[1] 为例,其为一个 4 * 4 的矩阵,其中第 1 行只有第 1 列和第 3 列为白色,表示计算特征 ① 的注意力时,只考虑 ① 和 ③ 特征。

显而易见的是,需要生成多少 mask,取决于分窗后有多少个窗口;每个 mask 的内容,取决于对应窗口内的边缘窗口组合形式。而如果输入图片的尺寸发生变化,那么整体的特征图尺寸、分出的窗口数量也会发生变化,进而影响 mask 的计算。因此,如果要支持动态的输入尺寸,必须同样动态地生成这些 mask。

幸运的是,这种动态生成 mask 的计算量不高,也不会涉及到插值等操作。通过在前向推理时根据输入图像尺寸动态生成这些 mask,MMClassification 同样支持了 Swin-Transformer 的动态输入尺寸。

解决了以上两个问题,就可以使绝大部分 Transformer 结构的视觉主干网络支持动态的输入图像尺寸。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-03-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档