1. For training, you can wrap the model with DataParallel and train on multiple GPUs; the batch dimension is split evenly across the GPUs. A minimal sketch follows.
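A minimal sketch of this, assuming nothing about the actual network: the toy model, loss, and dummy tensors below are placeholders, not any specific SR architecture.

```python
import torch
import torch.nn as nn

# Tiny stand-in for an actual SR network; replace with your own model.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(64, 3, 3, padding=1),
)

if torch.cuda.device_count() > 1:
    # Each forward call splits the batch dimension evenly across the visible
    # GPUs and gathers the outputs back on the default device.
    model = nn.DataParallel(model)
model = model.cuda()

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dummy LR/HR pair standing in for one batch from a real data loader.
lr_img = torch.randn(16, 3, 48, 48).cuda()
hr_img = torch.randn(16, 3, 48, 48).cuda()

optimizer.zero_grad()
loss = criterion(model(lr_img), hr_img)
loss.backward()
optimizer.step()
```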

2. For testing

This is a problem I ran into recently: in super-resolution the input image can be fairly large, so inference runs out of memory (OOM), yet at test time there is no batch dimension left to split.

2.1 Split the image into smaller patches

This is the common workaround, but its drawbacks are that the stitched result may show blocking artifacts at patch boundaries, and the measured PSNR and SSIM can deviate noticeably (see the DBPN paper). A rough sketch of tiled inference is shown below.
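The sketch below assumes a model that maps an LR tensor to an SR tensor; the tile size, overlap, and averaging of overlapped regions are illustrative choices, not a fixed recipe.

```python
import torch

@torch.no_grad()
def tiled_inference(model, lr_img, scale=1, tile=128, overlap=16):
    """Run the model on overlapping tiles of lr_img (1, C, H, W) and stitch
    the outputs, averaging where tiles overlap to soften seam artifacts."""
    _, c, h, w = lr_img.shape
    out = torch.zeros(1, c, h * scale, w * scale, device=lr_img.device)
    weight = torch.zeros_like(out)      # how many tiles covered each pixel
    step = tile - overlap
    for top in range(0, h, step):
        for left in range(0, w, step):
            # Clamp so the tile never runs past the image border.
            t = min(top, max(h - tile, 0))
            l = min(left, max(w - tile, 0))
            sr = model(lr_img[:, :, t:t + tile, l:l + tile])
            ts, ls = t * scale, l * scale
            out[:, :, ts:ts + sr.size(2), ls:ls + sr.size(3)] += sr
            weight[:, :, ts:ts + sr.size(2), ls:ls + sr.size(3)] += 1
    return out / weight
```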

2.2 Checkpoint: trading compute time for GPU memory

The definition from the official documentation:

torch.utils.checkpoint.checkpoint(function, *args)

Checkpoint a model or part of the model

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in the backward pass. It can be applied on any part of a model.

Specifically, in the forward pass, function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values.

Warning

Checkpointing doesn’t work with torch.autograd.grad(), but only with torch.autograd.backward().

Warning

If function invocation during backward does anything different than the one during forward, e.g., due to some global variable, the checkpointed version won’t be equivalent, and unfortunately it can’t be detected.

Using checkpoint inside forward(), so that each call covers only a few layers and the forward pass is broken into several such calls, is an effective way to deal with out-of-memory errors.
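A minimal sketch of that idea, assuming a plain stack of conv layers (the layer count and segment size are arbitrary): each checkpoint call covers only a few layers, so only the segment boundaries stay in memory and the rest is recomputed during backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedNet(nn.Module):
    def __init__(self, n_layers=32, channels=64, segment=4):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1) for _ in range(n_layers)
        )
        self.segment = segment  # layers recomputed per checkpointed chunk

    def _segment_fn(self, start, end):
        # checkpoint() runs this once under no_grad in forward, and once more
        # during backward to rebuild the intermediate activations.
        def run(x):
            for layer in self.layers[start:end]:
                x = torch.relu(layer(x))
            return x
        return run

    def forward(self, x):
        for start in range(0, len(self.layers), self.segment):
            end = min(start + self.segment, len(self.layers))
            x = checkpoint(self._segment_fn(start, end), x)
        return x

# The input must require grad, otherwise nothing flows back through the
# checkpointed segments.
x = torch.randn(1, 64, 128, 128, requires_grad=True)
out = CheckpointedNet()(x)
out.sum().backward()
```

For models that are already an nn.Sequential, torch.utils.checkpoint.checkpoint_sequential does this segmenting automatically.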