1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+
本文经AI新媒体量子位(公众号ID:QbitAI)授权转载,转载请联系出处。
多少人用PyTorch“炼丹”时都会被这个bug困扰。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s5.51cto.com/oss/202112/20/d70aaeef95367a01707958fb7070ba50.jpg)
一般情况下,你得找出当下占显存的没用的程序,然后kill掉。
如果不行,还需手动调整batch size到合适的大小……
有点麻烦。
现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s6.51cto.com/oss/202112/20/de66d71c3e58c34323bd773408f5cf92.jpg)
有多厉害?
相关项目在GitHub才发布没几天就收获了600+星。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s4.51cto.com/oss/202112/20/97b066244b61cbdce521baa4aa2ab1f1.jpg)
一行代码解决内存溢出错误
软件包名叫koila,已经上传PyPI,先安装一下:
- pip install koila
现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。
先定义input、label和model:
- # A batch of MNIST image
- input = torch.randn(8, 28, 28)
- # A batch of labels
- label = torch.randn(0, 10, [8])
- class NeuralNetwork(Module):
- def __init__(self):
- super(NeuralNetwork, self).__init__()
- self.flatten = Flatten()
- self.linear_relu_stack = Sequential(
- Linear(28 * 28, 512),
- ReLU(),
- Linear(512, 512),
- ReLU(),
- Linear(512, 10),
- )
- def forward(self, x):
- x = self.flatten(x)
- logits = self.linear_relu_stack(x)
- return logits
然后定义loss函数、计算输出和losses。
- loss_fn = CrossEntropyLoss()
- # Calculate losses
- out = nn(t)
- loss = loss_fn(out, label)
- # Backward pass
- nn.zero_grad()
- loss.backward()
好了,如何使用koila来防止内存溢出?
超级简单!
只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——
koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。
在本例中,batch=0,则修改如下:
- input = lazy(torch.randn(8, 28, 28), batch=0)
完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。
灵感来自TensorFlow的静态/懒惰评估
下面就来说说koila背后的工作原理。
“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。
koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。
它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。
而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。
又是算shape又是算内存的,koila听起来就很慢?
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s2.51cto.com/oss/202112/20/435162af4c8daa95e8fd6d90b5b971de.jpg)
NO。
即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。
而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。
你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?
是的,它也可以。
但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。
而koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。
不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s5.51cto.com/oss/202112/20/50392f893df6ede88d5997f65229a1d5.jpg)
以及现在只适用于常见的nn.Module类。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s6.51cto.com/oss/202112/20/e29fce83b555e825715af7c4aba23009.jpg)
ps. koila作者是一位叫做RenChu Wang的小哥。
![1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目揽星600+](https://s4.51cto.com/oss/202112/20/bd10bbfd037472413476425ba58030f5.jpg)
项目地址:
https://github.com/rentruewang/koila
相关文章
- 金融服务领域的大数据:即时分析
- 影响大数据、机器学习和人工智能未来发展的8个因素
- 从0开始构建一个属于你自己的PHP框架
- 如何将Hadoop集成到工作流程中?这6个优秀实践必看
- SEO公司使用大数据优化其模型的5种方法
- 关于Web Workers你需要了解的七件事
- 深入理解HTTPS原理、过程与实践
- 增强分析:数据和分析的未来
- PHP协程实现过程详解
- AI专家:大数据知识图谱——实战经验总结
- 关于PHP的错误机制总结
- 利用数据分析量化协同过滤算法的两大常见难题
- 怎么做大数据工作流调度系统?大厂架构师一语点破!
- 2019大数据处理必备的十大工具,从Linux到架构师必修
- OpenCV中的KMeans算法介绍与应用
- 教大家如果搭建一套phpstorm+wamp+xdebug调试PHP的环境
- CentOS下三种PHP拓展安装方法
- Go语言HTTP Server源码分析
- Go语言HTTP Server源码分析
- 2017年4月编程语言排行榜:Hack首次进入前五十