【ViT】Vision Transformer的实现01 patch embedding-LMLPHP
对于224*224的图像,将它输入到Transformer里面,就需要将图像展开成一系列的token,
如果逐像素视为token进行注意力的计算,难免计算量太大,因此一个更加合理的想法是将图像划分为一个个的patch
将每个patch进行embedding

现在对于一个224224的图像,我们设置每个patch图像块的尺寸是16,因此呢,我们可以从H和W两个维度将原图像进行分割,
224/16=14 14
14=196
也就说说把原图像分割成为了196个16*16的patch,因此我们就是要把196个patch进行embedding

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

我们首先定义一个PatchEmbed类,
它的初始化函数输入参数如下
img_size是输入图像的大小,224 像素
patch_size是我们分割的图像块的大小 16 像素
in_c 是输入图像的channels通道数,是3 RGB
embed_dim是每一个token的特征维数,也是我们embed输入的维数 我们定为768

self.num_patches = self.grid_size[0] * self.grid_size[1]
这一步相当于是计算了token的个数 1414=196
下面我们借用pytorch的2D卷积函数 nn.Conv2d
输入的通道数是3 embed_dim既是我们要求的输出通道数即每个token的特征维数,同时在卷积运算里面这代表着这层有多少个卷积核,
因为一般的卷积,卷积层输出通道数等于卷积核的个数。
然后卷积核的大小就是patch的大小16,步长的大小也是patch的大小16,这个意思就相当于用16
16的卷积核以16的步长做卷积,实际上就是提取了一个14*14的特征图,相当于把patch都提取出来了

下面的前向计算函数里面

x = self.proj(x).flatten(2).transpose(1, 2)

x首先的输入是(8,3,224,224)
然后通过proj函数做了embedding 输出的是(8,768,14,164)
然后flatten(2),在第二维上进行展开 (8,768,14,14)变成了(8,768,14*14)
最后transpose(1, 2)把第一维和第二维做交换 (8,768,196)变成了(8,196,768)

03-09 01:57