Operator fusion for a convolutional residual block
torch.nn.Conv2d:
Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
groups: splits the input channels (and the corresponding kernels) into `groups` independent groups
conv = torch.nn.Conv2d(2, 4, 3, groups=2)
conv.weight.size()
>>> torch.Size([4, 1, 3, 3])
# i.e. the original weight of shape [4, 2, 3, 3] is split along the in_channels dimension
# 1. Fewer multiply-adds and fewer weights (both reduced by a factor of `groups`)
# 2. The channels considered also change: each output channel only sees the input channels in its own group
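A minimal sketch (reusing the toy shapes above) to verify how `groups` changes the weight shape and the parameter count:

import torch.nn as nn
conv_full = nn.Conv2d(2, 4, 3)               # groups=1, weight shape [4, 2, 3, 3]
conv_grouped = nn.Conv2d(2, 4, 3, groups=2)  # groups=2, weight shape [4, 1, 3, 3]
print(conv_full.weight.shape, conv_grouped.weight.shape)
# parameter count drops from 4*2*3*3 + 4 = 76 to 4*1*3*3 + 4 = 40
print(sum(p.numel() for p in conv_full.parameters()),
      sum(p.numel() for p in conv_grouped.parameters()))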
Rewrite both the 1x1_conv branch and the original input (identity) branch as 3x3_conv, then fuse the operators into a single convolution. This works because convolution is linear in its weights: convs with the same kernel shape, stride and padding, applied to the same input, can be merged by simply adding their weights and biases (see the quick check below).
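A standalone sketch of that linearity check, using random tensors whose names are purely illustrative:

import torch
import torch.nn.functional as F
x_demo = torch.randn(1, 2, 9, 9)
w1, w2 = torch.randn(2, 2, 3, 3), torch.randn(2, 2, 3, 3)
b1, b2 = torch.randn(2), torch.randn(2)
lhs = F.conv2d(x_demo, w1, b1, padding=1) + F.conv2d(x_demo, w2, b2, padding=1)
rhs = F.conv2d(x_demo, w1 + w2, b1 + b2, padding=1)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True, up to float rounding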
# res_block = 3x3_conv + 1x1_conv + input
import torch
import torch.nn.functional as F
import torch.nn as nn
in_channels, out_channels, kernel_size, w, h = 2, 2, 3, 9, 9
x = torch.ones(1, in_channels, w, h)
# Method 1 (naive: run all three branches separately):
conv_2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
# 1x1 (point-wise) convolution
conv_2d_point_wise = nn.Conv2d(in_channels, out_channels, 1)
result_1 = conv_2d(x) + conv_2d_point_wise(x) + x
# Method 2 (operator fusion):
# rewrite the point-wise conv and the identity branch (x itself) as 3x3 convolutions,
# then merge the three convolutions into a single one
# [2,2,1,1] -> [2,2,3,3]
point_wise_to_conv_weight = F.pad(conv_2d_point_wise.weight, [1, 1, 1, 1])
conv_2d_for_point_wise = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_point_wise.weight = nn.Parameter(point_wise_to_conv_weight)
conv_2d_for_point_wise.bias = conv_2d_point_wise.bias
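# Optional sanity check: with 'same' padding and zeros everywhere except the centre tap,
# the padded 3x3 kernel should reproduce the original 1x1 conv
print(torch.allclose(conv_2d_for_point_wise(x), conv_2d_point_wise(x), atol=1e-6))  # expect True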
# identity branch: input -> 3x3_conv
zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)   # [1, 3, 3], all zeros
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)   # [1, 3, 3], 1 at the centre
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], dim=0), 0)  # [1, 2, 3, 3]
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], dim=0), 0)  # [1, 2, 3, 3]
# stars_zeros copies only the first input channel
# zeros_stars copies only the second input channel
identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], dim=0)
identity_to_conv_bias = torch.zeros([out_channels])
conv_2d_for_identity = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)
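# Optional sanity check: the hand-built identity kernels (centre tap of 1 on the matching
# channel, zero bias) should return the input unchanged
print(torch.allclose(conv_2d_for_identity(x), x, atol=1e-6))  # expect True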
result_2 = conv_2d(x) + conv_2d_for_point_wise(x) + conv_2d_for_identity(x)
# Now fuse the three operators into a single 3x3 convolution
conv_2d_for_fusion = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_fusion.weight = nn.Parameter(
conv_2d.weight.data + conv_2d_for_point_wise.weight.data + conv_2d_for_identity.weight.data)
conv_2d_for_fusion.bias = nn.Parameter(
conv_2d.bias.data + conv_2d_for_point_wise.bias.data + conv_2d_for_identity.bias.data)
result_3 = conv_2d_for_fusion(x)
print(torch.all(torch.isclose(result_1, result_2)))  # expect True
print(torch.all(torch.isclose(result_1, result_3)))  # expect True
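The same merge can be packaged as a small helper for reuse. This is only an illustrative sketch (fuse_res_block is not a library function), assuming stride 1, 'same' padding on the 3x3 branch, bias enabled on both convs, and in_channels == out_channels so the identity branch is well defined:

def fuse_res_block(conv3x3: nn.Conv2d, conv1x1: nn.Conv2d) -> nn.Conv2d:
    out_c, in_c, k, _ = conv3x3.weight.shape
    fused = nn.Conv2d(in_c, out_c, k, padding='same')
    # 1x1 branch -> 3x3 kernel: zero-pad around the centre tap
    w_1x1 = F.pad(conv1x1.weight, [1, 1, 1, 1]).data
    # identity branch -> 3x3 kernel: 1 at the centre of the matching channel
    w_id = torch.zeros(out_c, in_c, k, k)
    for i in range(out_c):
        w_id[i, i, k // 2, k // 2] = 1.0
    fused.weight = nn.Parameter(conv3x3.weight.data + w_1x1 + w_id)
    fused.bias = nn.Parameter(conv3x3.bias.data + conv1x1.bias.data)
    return fused

fused = fuse_res_block(conv_2d, conv_2d_point_wise)
print(torch.allclose(result_1, fused(x), atol=1e-6))  # expect True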