pdf code

Container: Context Aggregation Network

这篇paper从临近矩阵的形成方式的角度,将MLP Mixer, CNN以及transformer统一起来.

Method:

一个残差网络层的计算表达可以抽象地表达为

其中是可学习的参数.

其中,定义关联矩阵, 代表邻域的关注,那么网络层可以表达为

其中是X的一个线性投影.通过引入不同的关联矩阵, 这个模块的拟合capacity可以进一步提升,其中可以采用multi-head的版本

Typical instance of the Context Aggregation Module

Transformer:

Depthwise Convolution:

这个矩阵的形态是静态的,值是可以学习的.

MLP-Mixer:

其计算公式为, 关联矩阵为

因而这个矩阵是完全静态的且完全可学习的。但是完全没有参数共享。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., seq_l=196):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        #Uncomment this line for Container-PAM
        #self.static_a = nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l , 1 + seq_l))
        #trunc_normal_(self.static_a)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).
        permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        #Uncomment this line for Container-PAM
        #attn = attn + self.static_a

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x