
Operator Fusion for a Convolutional Residual Block

Author: har, 2023-09-27 16:49:49



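torch.nn.Conv2d:

Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)

groups: splits the convolution into that many groups. Each group of output channels convolves only its own group of input channels.

conv = torch.nn.Conv2d(2, 4, 3, groups=2)
conv.weight.size()
>>> torch.Size([4, 1, 3, 3])

# With groups=1 the weight would be [4, 2, 3, 3]; grouping splits it along the in_channels dimension.
# 1. This reduces the number of multiply-accumulate operations.
# 2. Each output channel now sees only the input channels in its own group.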
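
The effect of groups on the weight tensor can be checked directly. The sketch below is only an illustration (the variable names dense and grouped are not from the original post); it compares the weight shape and parameter count of the same layer with groups=1 and groups=2.

```python
import torch
import torch.nn as nn

# Same in/out channels and kernel size, with and without grouping.
dense = nn.Conv2d(2, 4, 3, groups=1)
grouped = nn.Conv2d(2, 4, 3, groups=2)

# Weight layout is (out_channels, in_channels // groups, kH, kW).
dense_shape = tuple(dense.weight.shape)
grouped_shape = tuple(grouped.weight.shape)

# Number of weight elements (bias excluded).
dense_params = dense.weight.numel()
grouped_params = grouped.weight.numel()

print(dense_shape, dense_params)
print(grouped_shape, grouped_params)
```

With two groups, each output channel is connected to half of the input channels, so the weight tensor holds half as many elements.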


Rewrite both the 1x1 conv and the raw input (the identity branch) as 3x3 convs, then fuse the three operators into one.
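
The fusion works because convolution is linear in its kernel. As a sketch of the reasoning, using notation that is not in the original post: let $W_3, b_3$ be the 3x3 weight and bias, $W_1, b_1$ the 1x1 weight and bias, $\tilde{W}_1$ the 1x1 kernel zero-padded to 3x3, and $W_{\mathrm{id}}$ a 3x3 kernel that reproduces the input (kernel positions $u, v$ are 0-indexed). This assumes stride 1, 'same' padding, and equal input and output channel counts.

```latex
\begin{aligned}
y &= (W_3 * x + b_3) + (W_1 * x + b_1) + x \\
  &= W_3 * x + \tilde{W}_1 * x + W_{\mathrm{id}} * x + (b_3 + b_1) \\
  &= (W_3 + \tilde{W}_1 + W_{\mathrm{id}}) * x + (b_3 + b_1),
\end{aligned}
\qquad
W_{\mathrm{id}}[o, i, u, v] =
\begin{cases}
1 & \text{if } o = i \text{ and } u = v = 1 \\
0 & \text{otherwise}
\end{cases}
```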

# res_block = 3*3_conv + 1*1_conv + input

import torch
import torch.nn.functional as F
import torch.nn as nn

in_channels, out_channels, kernel_size, w, h = 2, 2, 3, 9, 9
x = torch.ones(1, in_channels, w, h)

# Method 1 (native): evaluate the three branches separately
conv_2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
# 1x1 (point-wise) convolution
conv_2d_point_wise = nn.Conv2d(in_channels, out_channels, 1)
result_1 = conv_2d(x) + conv_2d_point_wise(x) + x

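# Method 2 (operator fusion):
# rewrite both the point-wise conv and x itself as 3x3 convolutions,
# then merge the three convolutions into one

# Zero-pad the 1x1 kernel to 3x3: [2,2,1,1] -> [2,2,3,3]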
point_wise_to_conv_weight = F.pad(conv_2d_point_wise.weight, [1, 1, 1, 1])
conv_2d_for_point_wise = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_point_wise.weight = nn.Parameter(point_wise_to_conv_weight)
conv_2d_for_point_wise.bias = conv_2d_point_wise.bias

# input (identity branch) -> 3x3_conv
zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], dim=0), 0)
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], dim=0), 0)
# stars_zeros picks out the first channel of the input
# zeros_stars picks out the second channel of the input

identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], dim=0)
identity_to_conv_bias = torch.zeros([out_channels])
conv_2d_for_identity = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_point_wise(x) + conv_2d_for_identity(x)

# Now fuse the three operators: sum the weights and sum the biases
conv_2d_for_fusion = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')

conv_2d_for_fusion.weight = nn.Parameter(
    conv_2d.weight.data + conv_2d_for_point_wise.weight.data + conv_2d_for_identity.weight.data)

conv_2d_for_fusion.bias = nn.Parameter(
    conv_2d.bias.data + conv_2d_for_point_wise.bias.data + conv_2d_for_identity.bias.data)
result3 = conv_2d_for_fusion(x)

# Both rewritten forms should match the native result
print(torch.all(torch.isclose(result_1, result3)))
print(torch.all(torch.isclose(result_1, result2)))
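
The script above hard-codes two channels and tests on an all-ones input. A minimal sketch of the same idea for an arbitrary channel count, checked on random input, might look like the following. The helper names identity_kernel and fuse_res_block are hypothetical and not part of the original post or of PyTorch, and the sketch assumes stride 1, 'same' padding, bias enabled, and equal input and output channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def identity_kernel(channels: int, kernel_size: int = 3) -> torch.Tensor:
    """Build a [channels, channels, k, k] kernel whose convolution returns the input."""
    weight = torch.zeros(channels, channels, kernel_size, kernel_size)
    center = kernel_size // 2
    idx = torch.arange(channels)
    # Output channel c copies input channel c through the kernel center.
    weight[idx, idx, center, center] = 1.0
    return weight


def fuse_res_block(conv3x3: nn.Conv2d, conv1x1: nn.Conv2d) -> nn.Conv2d:
    """Fuse conv3x3(x) + conv1x1(x) + x into a single 3x3 convolution."""
    channels = conv3x3.in_channels
    assert conv3x3.out_channels == channels
    assert conv1x1.in_channels == channels and conv1x1.out_channels == channels
    fused = nn.Conv2d(channels, channels, 3, padding='same')
    with torch.no_grad():
        # Zero-pad the 1x1 kernel to 3x3, then sum all three kernels.
        padded_1x1 = F.pad(conv1x1.weight, [1, 1, 1, 1])
        fused.weight.copy_(conv3x3.weight + padded_1x1 + identity_kernel(channels))
        # The identity branch has no bias, so only two biases are summed.
        fused.bias.copy_(conv3x3.bias + conv1x1.bias)
    return fused


torch.manual_seed(0)
channels = 4
conv3x3 = nn.Conv2d(channels, channels, 3, padding='same')
conv1x1 = nn.Conv2d(channels, channels, 1)
x_rand = torch.randn(2, channels, 9, 9)

with torch.no_grad():
    expected = conv3x3(x_rand) + conv1x1(x_rand) + x_rand
    fused_out = fuse_res_block(conv3x3, conv1x1)(x_rand)

# Compare with a small absolute tolerance for float32 rounding.
matches = torch.allclose(expected, fused_out, atol=1e-5)
print(matches)
```

Random input is a stricter check than an all-ones tensor, since a kernel that mixes channels or positions incorrectly can still pass on constant input.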
