本文介绍如何在PyTorch中用ConvTranspose2d算子等价替代Upsample算子(nearest模式,整数放大倍数)。

背景介绍:

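  • 某些AI加速卡上Upsample算子的性能不够高,能否用别的算子临时替代?
  • 可以手动推导出ConvTranspose2d的权值,使其与Upsample算子等价:当kernel_size = stride = 放大倍数(此处为2)时,每个输入像素恰好被展开成一个互不重叠的2×2输出块;只要令输入通道i到输出通道o的卷积核在i = o时为全1、否则为全0,输出就与nearest上采样完全相同
  • 也可以搭建一个模型,把同一个输入分别送入ConvTranspose2d和Upsample算子,以两者输出之间的L1Loss为目标训练ConvTranspose2d的权值,使L1Loss最小
  • 当网络收敛后,对ConvTranspose2d的权值做舍入处理
  • 最后用舍入后的权值初始化ConvTranspose2d,并验证两者的输出完全一致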

网络结构

import onnx
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np

class UpsampleModel(torch.nn.Module):
    """Run nearest Upsample and a ConvTranspose2d side by side on the same input.

    The two branches are compared in ``train`` (to learn the deconv weight)
    and in ``val`` (to check exact equivalence).  With
    ``kernel_size == stride == scale_factor`` every input pixel is expanded
    into a non-overlapping ``scale_factor x scale_factor`` output patch, which
    is exactly what nearest upsampling by an integer factor does.  A suitable
    weight therefore makes the two outputs identical.

    Args:
        channels: number of input (and output) channels. Defaults to 3.
        scale_factor: integer upsampling factor, also used as the deconv
            kernel size and stride. Defaults to 2.
    """

    def __init__(self, channels=3, scale_factor=2):
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale_factor, mode='nearest')
        # Weight shape is (in_channels, out_channels, k, k).  There is no bias,
        # so the deconv can be a pure linear copy of its input.
        self.deconv1 = nn.ConvTranspose2d(
            channels, channels, scale_factor, scale_factor, groups=1, bias=False
        )

    def forward(self, x):
        """Return the tuple ``(upsampled x, deconvolved x)``."""
        return self.up(x), self.deconv1(x)

训练ConvTranspose2d的权值

def train(epochs=2100, steps_per_epoch=100, input_shape=(1, 3, 224, 224),
          lr=0.001, tol=1e-4, model=None):
    """Learn ConvTranspose2d weights that reproduce nearest Upsample.

    Random inputs are fed through both branches of the model.  The L1 loss
    between the two outputs is minimised with Adam.  Training stops as soon
    as the epoch-average loss drops below ``tol``.  At that point the deconv
    weights are rounded to the nearest integer, printed and returned.

    Args:
        epochs: maximum number of epochs.
        steps_per_epoch: optimisation steps per epoch; each step uses a fresh
            random batch.  Must be >= 1.
        input_shape: shape of each random input batch.  Its channel count
            must match the model's deconv ``in_channels``.
        lr: Adam learning rate.
        tol: stop once the epoch-average L1 loss is below this value.
        model: optional model whose ``forward`` returns
            ``(upsample_out, deconv_out)`` and which has a ``deconv1`` layer.
            Defaults to a fresh ``UpsampleModel()``.

    Returns:
        The rounded ``deconv1`` weights as a numpy array if training
        converged.  Negative zeros are normalised to ``0.0``.  Returns
        ``None`` if ``tol`` was not reached within ``epochs`` epochs.

    Raises:
        ValueError: if ``steps_per_epoch`` is smaller than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError('steps_per_epoch must be >= 1')
    if model is None:
        model = UpsampleModel()
    criterion = nn.L1Loss()
    # For UpsampleModel, nn.Upsample has no parameters, so effectively only
    # deconv1.weight is optimised.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    avg_loss = float('nan')
    for epoch in range(epochs):
        running_loss = 0.0
        for _ in range(steps_per_epoch):
            # A new random batch every step, so the weights must work for any input.
            input_data = torch.randn(input_shape)
            optimizer.zero_grad()
            out0, out1 = model(input_data)
            loss = criterion(out0, out1)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / steps_per_epoch
        print('[%d] loss: %f' % (epoch + 1, avg_loss))
        if avg_loss < tol:
            w = model.deconv1.weight.detach().cpu().numpy()
            # np.round yields -0.0 for small negative weights; adding 0.0 maps
            # them to +0.0 so the printed result is clean.
            w_rounded = np.round(w) + 0.0
            print(w_rounded)
            return w_rounded
    print('not converged after %d epochs, last avg loss: %f' % (epochs, avg_loss))
    return None
# Run the weight search defined above.  It prints the loss for every epoch and,
# once the epoch-average loss drops below 1e-4, the rounded deconv weights.
train()

结果(舍入后的权值;其中 -0. 即 0,可见权值正是按通道的单位矩阵乘以2×2的全1卷积核)

[[[[ 1.  1.]
   [ 1.  1.]]
  [[-0. -0.]
   [-0. -0.]]
  [[ 0. -0.]
   [-0. -0.]]]
 [[[ 0.  0.]
   [ 0. -0.]]
  [[ 1.  1.]
   [ 1.  1.]]
  [[ 0.  0.]
   [ 0. -0.]]]
 [[[ 0. -0.]
   [-0. -0.]]
  [[ 0. -0.]
   [ 0.  0.]]
  [[ 1.  1.]
   [ 1.  1.]]]]

用上面生成的权值验证

def val():
    """Check that a hand-set ConvTranspose2d reproduces nearest Upsample exactly.

    The deconv weight is the rounded result found by ``train``.  For input
    channel ``i`` and output channel ``o``, the 2x2 kernel is all ones when
    ``i == o`` and all zeros otherwise.  In effect, each input pixel is copied
    into its 2x2 output patch on the same channel.

    Returns:
        True if both branches give element-wise identical outputs on a random
        input, else False.  The result is also printed.
    """
    # Per-channel identity (3x3) broadcast against a 2x2 kernel of ones,
    # giving shape (in_channels, out_channels, kH, kW) = (3, 3, 2, 2).
    w = np.eye(3, dtype=np.float32)[:, :, None, None] * np.ones((2, 2), dtype=np.float32)
    input_shape = (1, 3, 224, 224)
    model = UpsampleModel().eval()
    with torch.no_grad():
        model.deconv1.weight.copy_(torch.from_numpy(w))  # set the weights
        input_data = torch.randn(input_shape)
        out0, out1 = model(input_data)
    # Exact (not approximate) equality: the weights are 0/1, so no rounding error.
    ret = torch.equal(out0, out1)
    print(ret)
    return ret
# Run the equivalence check between the Upsample and ConvTranspose2d outputs.
val()

输出(ret 的值,表示两路输出逐元素完全相等)

True
02-26 13:17