前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TorchScript 系列解读(一):初识 TorchScript

TorchScript 系列解读(一):初识 TorchScript

作者头像
OpenMMLab 官方账号
发布2022-04-09 16:27:50
5510
发布2022-04-09 16:27:50
举报
文章被收录于专栏:OpenMMLabOpenMMLab

小伙伴们好呀,不久前我们推出了模型部署入门系列教程,受到了大家的一致好评,也收到了很多小伙伴的催更,后续教程正在准备中,将在不久后跟大家见面,敬请期待哦~

今天,我们又将开启新的 TorchScript 解读系列教程,带领大家玩转 PyTorch 模型部署。感兴趣的小伙伴一起往下看吧~

什么是 TorchScript

PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。PyTorch 源码 Readme 中还专门为此做了一张动态图:

对研究员而言,PyTorch 能极大地提高想 idea、做实验、发论文的效率,是训练框架中的豪杰,但是它不适合部署。动态建图带来的优势对于性能要求更高的应用场景而言更像是缺点,非固定的网络结构给网络结构分析并进行优化带来了困难,多数参数都能以 Tensor 形式传输也让资源分配变成一件闹心的事。另外由于图是由 python 代码来构建的,一方面部署要依赖 python 环境,另一方面模型也毫无保密性可言。

而 TorchScript 就是为了解决这个问题而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。今天我们先简要地介绍一些 TorchScript 的功能,让大家有一个初步的认识,进阶的解读会陆续推出~

模型转换

作为模型部署的一个范式,通常我们都需要生成一个模型的中间表示(IR),这个 IR 拥有相对固定的图结构,所以更容易优化,让我们看一个例子:

代码语言:javascript
复制
import torch
from torchvision.models import resnet18

# 使用PyTorch model zoo中的resnet18作为例子
model = resnet18()
model.eval()

# 通过trace的方法生成IR需要一个输入样例
dummy_input = torch.rand(1, 3, 224, 224)

# IR生成
with torch.no_grad():
    jit_model = torch.jit.trace(model, dummy_input)

到这里就将 PyTorch 的模型转换成了 TorchScript 的 IR。这里我们使用了 trace 模式来生成 IR,所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图。关于 trace 的过程我们会在未来的分享中进行解读。

那么这个 IR 中到底都有些什么呢?我们可以可视化一下其中的 layer1 看看:

代码语言:javascript
复制
jit_layer1 = jit_model.layer1
print(jit_layer1.graph)

# graph(%self.6 : __torch__.torch.nn.modules.container.Sequential,
#       %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)):
#   %1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.6)
#   %2 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.6)
#   %6 : Tensor = prim::CallMethod[name="forward"](%2, %4)
#   %7 : Tensor = prim::CallMethod[name="forward"](%1, %6)
#   return (%7)

是不是有点摸不着头脑?TorchScript 有它自己对于 Graph 以及其中元素的定义,对于第一次接触的人来说可能比较陌生,但是没关系,我们还有另一种可视化方式:

代码语言:javascript
复制
print(jit_layer1.code)

# def forward(self,
#     argument_1: Tensor) -> Tensor:
#   _0 = getattr(self, "1")
#   _1 = (getattr(self, "0")).forward(argument_1, )
#   return (_0).forward(_1, )

没错,就是代码!TorchScript 的 IR 是可以还原成 python 代码的,如果你生成了一个 TorchScript 模型并且想知道它的内容对不对,那么可以通过这样的方式来做一些简单的检查。

刚才的例子中我们使用 trace 的方法生成 IR。除了 trace 之外,PyTorch 还提供了另一种生成 TorchScript 模型的方法:script。这种方式会直接解析网络定义的 python 代码,生成抽象语法树 AST,因此这种方法可以解决一些 trace 无法解决的问题,比如对 branch/loop 等数据流控制语句的建图。script 方式的建图有很多有趣的特性,会在未来的分享中做专题分析,敬请期待。

模型优化

聪明的同学可能发现了,上面的可视化中只有 resnet18 里 forward 的部分,其中的子模块信息是不是丢失了呢?如果没有丢失,那么怎么样才能确定子模块的内容是否正确呢?别担心,还记得我们说过 TorchScript 支持对网络的优化吗,这里我们就可以用一个 pass 解决这个问题:

代码语言:javascript
复制
# 调用inline pass,对graph做变换
torch._C._jit_pass_inline(jit_layer1.graph)
print(jit_layer1.code)

# def forward(self,
#     argument_1: Tensor) -> Tensor:
#   _0 = getattr(self, "1")
#   _1 = getattr(self, "0")
#   _2 = _1.bn2
#   _3 = _1.conv2
#   _4 = _1.bn1
#   input = torch._convolution(argument_1, _1.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _5 = _4.running_var
#   _6 = _4.running_mean
#   _7 = _4.bias
#   input0 = torch.batch_norm(input, _4.weight, _7, _6, _5, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input1 = torch.relu_(input0)
#   input2 = torch._convolution(input1, _3.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _8 = _2.running_var
#   _9 = _2.running_mean
#   _10 = _2.bias
#   out = torch.batch_norm(input2, _2.weight, _10, _9, _8, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input3 = torch.add_(out, argument_1, alpha=1)
#   input4 = torch.relu_(input3)
#   _11 = _0.bn2
#   _12 = _0.conv2
#   _13 = _0.bn1
#   input5 = torch._convolution(input4, _0.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _14 = _13.running_var
#   _15 = _13.running_mean
#   _16 = _13.bias
#   input6 = torch.batch_norm(input5, _13.weight, _16, _15, _14, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input7 = torch.relu_(input6)
#   input8 = torch._convolution(input7, _12.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _17 = _11.running_var
#   _18 = _11.running_mean
#   _19 = _11.bias
#   out0 = torch.batch_norm(input8, _11.weight, _19, _18, _17, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input9 = torch.add_(out0, input4, alpha=1)
#   return torch.relu_(input9)

这里我们就能看到卷积、batch_norm、relu 等熟悉的算子了。

上面代码中我们使用了一个名为 inline 的 pass,将所有子模块进行内联,这样我们就能看见更完整的推理代码。pass 是一个来源于编译原理的概念,一个 TorchScript 的 pass 会接收一个图,遍历图中所有元素进行某种变换,生成一个新的图。我们这里用到的 inline 起到的作用就是将模块调用展开,尽管这样做并不能直接影响执行效率,但是它其实是很多其他 pass 的基础。PyTorch 中定义了非常多的 pass 来解决各种优化任务,未来我们会做一些更详细的介绍。

序列化

不管是哪种方法创建的 TorchScript 都可以进行序列化,比如:

代码语言:javascript
复制
# 将模型序列化
jit_model.save('jit_model.pth')
# 加载序列化后的模型
jit_model = torch.jit.load('jit_model.pth')

序列化后的模型不再与 python 相关,可以被部署到各种平台上。

PyTorch 提供了可以用于 TorchScript 模型推理的 c++ API,序列化后的模型终于可以不依赖 python 进行推理了:

代码语言:javascript
复制
// 加载生成的torchscript模型
auto module = torch::jit::load('jit_model.pth');
// 根据任务需求读取数据
std::vector<torch::jit::IValue> inputs = ...;
// 计算推理结果
auto output = module.forward(inputs).toTensor();

与其他组件的关系

与 torch.onnx 的关系

ONNX 是业界广泛使用的一种神经网络中间表示,PyTorch 自然也对 ONNX 提供了支持。torch.onnx.export 函数可以帮助我们把 PyTorch 模型转换成 ONNX 模型,这个函数会使用 trace 的方式记录 PyTorch 的推理过程。聪明的同学可能已经想到了,没错,ONNX 的导出,使用的正是 TorchScript 的 trace 工具。具体步骤如下:

1. 使用 trace 的方式先生成一个 TorchScipt 模型,如果你转换的本身就是 TorchScript 模型,则可以跳过这一步。

2. 使用许多 pass 对 1 中生成的模型进行变换,其中对 ONNX 导出最重要的一个 pass 就是ToONNX,这个 pass 会进行一个映射,将 TorchScript 中 prim、aten 空间下的算子映射到onnx空间下的算子。

3. 使用 ONNX 的 proto 格式对模型进行序列化,完成 ONNX 的导出。

关于 ONNX 导出的实现以及算子映射的方式将会在未来的分享中详细展开。

与 torch.fx 的关系

PyTorch1.9 开始添加了 torch.fx 工具,根据官方的介绍,它由符号追踪器 (symbolic tracer),中间表示(IR), Python 代码生成 (Python code generation) 等组件组成,实现了 python->python 的翻译。是不是和 TorchScript 看起来有点像?

其实他们之间联系不大,可以算是互相垂直的两个工具,为解决两个不同的任务而诞生。

TorchScript 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。

FX 的主要用途是进行 python->python 的翻译,它的 IR 中节点类型更简单,比如函数调用、属性提取等等,这样的 IR 学习成本更低更容易编辑。使用 FX 来编辑图通常是为了实现某种特定功能,比如给模型插入量化节点等,避免手动编辑网络造成的重复劳动。

这两个工具可以同时使用,比如使用 FX 工具编辑模型来让训练更便利、功能更强大;然后用 TorchScript 将模型加速部署到特定平台。

希望通过以上的分享,大家对 TorchScript 有了一个初步的认识,未来我们将会为大家带来更进阶的解读,欢迎大家持续关注。另外值得分享的是,MMDeploy 已开始对 TorchScript 提供支持

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2022-03-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档