Introduction

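  • When fine-tuning a pre-trained model, we usually pass the output of the Transformer's last layer through an FC / Bi-LSTM… layer to produce the final result. In fact, however, each Transformer layer captures linguistic information at a different granularity (i.e. with surface features in lower layers, syntactic features in middle layers, and semantic features in higher layers), so it is worth choosing a pooling strategy suited to each task.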



Given input_ids and attention_mask, a HuggingFace Transformers model returns 2 outputs (3 if configured). The rest of this post discusses pooling strategies that make use of these outputs.

  • pooler output [batch_size, hidden_size]: the last-layer hidden state of the first token of the sequence (the classification token), further processed by a Linear layer and a Tanh activation. For BERT, the Linear layer weights are trained on the next sentence prediction (classification) objective during pretraining. Models such as BertModel can be built without the pooler by passing add_pooling_layer=False when the model is constructed (e.g. BertModel(config, add_pooling_layer=False)).
  • last hidden state [batch_size, seq_len, hidden_size]: the sequence of hidden states at the output of the last layer.
  • hidden states [n_layers, batch_size, seq_len, hidden_size]: the hidden states of every layer for every token, returned as a tuple of n_layers tensors that can be stacked into this shape. For base models, n_layers = 1 embedding layer + 12 Transformer layers = 13, and index 0 is the embedding output. They are only returned when output_hidden_states=True is passed to the model (or set in its config).
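To see these outputs concretely, here is a minimal sketch that assumes the `transformers` package is installed. It builds a tiny, randomly initialised `BertModel` (the config sizes are arbitrary, chosen only so that nothing has to be downloaded), prints the shapes of the three outputs, and checks that the pooler output is simply Tanh(Linear(·)) applied to the last-layer [CLS] hidden state.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random-weight BERT (sizes are arbitrary, for illustration only)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (3, 7))   # [batch_size=3, seq_len=7]
attention_mask = torch.ones_like(input_ids)
attention_mask[0, 5:] = 0                                  # pretend the first sequence is padded

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                    output_hidden_states=True)

print(outputs.last_hidden_state.shape)   # torch.Size([3, 7, 32])
print(outputs.pooler_output.shape)       # torch.Size([3, 32])
print(len(outputs.hidden_states))        # 3 (embedding output + 2 layers)

# The pooler is a Linear layer + Tanh applied to the last-layer [CLS] hidden state
cls_state = outputs.last_hidden_state[:, 0]
with torch.no_grad():
    manual_pooled = model.pooler.activation(model.pooler.dense(cls_state))
print(torch.allclose(manual_pooled, outputs.pooler_output, atol=1e-6))  # True
```

For a base-size checkpoint the same call returns 13 hidden states of width 768.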

Different Pooling Strategies

Pooler Output

  • Pooler output + FC
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = nn.Linear(config.hidden_size, 1)(pooler_output) # regression head

Last Hidden State Output

  • CLS Embeddings: [CLS] Embed + FC
cls_embeddings = last_hidden_state[:, 0]
logits = nn.Linear(config.hidden_size, 1)(cls_embeddings) # regression head
  • Mean Pooling: Last Hidden State Output + Mean Pooling (remember to ignore padding tokens using the attention mask)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
sum_mask = input_mask_expanded.sum(1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
mean_embeddings = sum_embeddings / sum_mask
logits = nn.Linear(config.hidden_size, 1)(mean_embeddings) # regression head
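The masked mean above can also be packaged as a reusable module. The following is a minimal sketch (the class name `MeanPooling` is our own, not a library API); the check at the end confirms that appending padding positions, whatever values they hold, leaves the pooled vector unchanged, which is the whole point of the mask.

```python
import torch
import torch.nn as nn


class MeanPooling(nn.Module):
    """Masked mean over the sequence dimension (padding tokens are ignored)."""

    def forward(self, last_hidden_state, attention_mask):
        # [batch_size, seq_len] -> [batch_size, seq_len, 1], broadcast over hidden_size
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)      # [batch_size, hidden_size]
        counts = mask.sum(dim=1).clamp(min=1e-9)            # [batch_size, 1], avoids division by zero
        return summed / counts


torch.manual_seed(0)
pool = MeanPooling()
hidden = torch.randn(1, 4, 8)                  # one sequence of 4 real tokens
mask = torch.ones(1, 4, dtype=torch.long)

# Same sequence with 3 padding positions appended (arbitrary values in the padded slots)
padded_hidden = torch.cat([hidden, torch.randn(1, 3, 8)], dim=1)
padded_mask = torch.cat([mask, torch.zeros(1, 3, dtype=torch.long)], dim=1)

print(torch.allclose(pool(hidden, mask), pool(padded_hidden, padded_mask)))  # True
print(torch.allclose(pool(hidden, mask), hidden.mean(dim=1)))                # True
```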
  • Max Pooling: Last Hidden State Output + Max Pooling (remember to ignore padding tokens using the attention mask, i.e. set the embeddings of masked tokens to a large negative value such as -1e9 so that they can never be the maximum)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
last_hidden_state = last_hidden_state.clone()
last_hidden_state[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
max_embeddings = torch.max(last_hidden_state, 1)[0]
logits = nn.Linear(config.hidden_size, 1)(max_embeddings) # regression head
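A module version of the masked max, as a sketch under our own naming (`MaxPooling`). It fills padding positions with `torch.finfo(dtype).min` rather than a hard-coded -1e9, a design choice on our part so that the fill value stays representable in fp16, and it uses the out-of-place `masked_fill`, so the model output is left untouched. In the toy input the padded position holds a large value that would win the max if it were not masked.

```python
import torch
import torch.nn as nn


class MaxPooling(nn.Module):
    """Masked max over the sequence dimension (padding tokens can never win)."""

    def forward(self, last_hidden_state, attention_mask):
        pad = (attention_mask == 0).unsqueeze(-1)              # [batch_size, seq_len, 1]
        fill_value = torch.finfo(last_hidden_state.dtype).min
        # masked_fill returns a new tensor, so the input is not modified in place
        masked = last_hidden_state.masked_fill(pad, fill_value)
        return masked.max(dim=1).values                        # [batch_size, hidden_size]


hidden = torch.tensor([[[1.0, -2.0],
                        [3.0, -5.0],
                        [100.0, 100.0]]])   # the last position is padding
mask = torch.tensor([[1, 1, 0]])
print(MaxPooling()(hidden, mask).tolist())  # [[3.0, -2.0]]
```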
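  • Mean + Max Pooling: (1) Last Hidden State Output + Max Pooling. (2) Last Hidden State Output + Mean Pooling. (3) Concatenate the two to get a final representation that is twice the hidden size.
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
# (1) masked max pooling: torch.max returns (values, indices), so keep the values
max_pooling_embeddings, _ = torch.max(last_hidden_state.masked_fill(input_mask_expanded == 0, -1e9), 1)
# (2) masked mean pooling
mean_pooling_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# (3) concatenate -> [batch_size, hidden_size * 2]
mean_max_embeddings = torch.cat((mean_pooling_embeddings, max_pooling_embeddings), 1)
logits = nn.Linear(config.hidden_size*2, 1)(mean_max_embeddings)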
  • Conv1D Pooling: Last Hidden State Output + 2 Conv1d layers, followed by a max over the sequence positions
cnn1 = nn.Conv1d(config.hidden_size, 256, kernel_size=2, padding=1)	# config.hidden_size is 768 for base models
cnn2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)

last_hidden_state = last_hidden_state.permute(0, 2, 1)	# [batch_size, embed_size, seq_len]
cnn_embeddings = F.relu(cnn1(last_hidden_state))	# [batch_size, 256, seq_len + 1]: kernel_size=2 with padding=1 adds one position
cnn_embeddings = cnn2(cnn_embeddings)	# [batch_size, 1, seq_len + 2]
logits, _ = torch.max(cnn_embeddings, 2)	# [batch_size, 1]; padding positions are not masked here
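With `kernel_size=2` and `padding=1`, each convolution lengthens the sequence by one position (L_out = L_in + 2·padding - kernel_size + 1), and padding tokens take part in the final max. The sketch below is a variant, not the original code: the class name `Conv1DPooling` is hypothetical, it switches to `kernel_size=3, padding=1` so the length is preserved and the attention mask still lines up, and it zeroes padded positions before each convolution and excludes them from the final max.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1DPooling(nn.Module):
    """Two Conv1d layers over the token axis, then a masked max over positions.

    kernel_size=3 with padding=1 keeps seq_len unchanged (a deviation from the
    kernel_size=2 snippet) so that the attention mask still lines up.
    """

    def __init__(self, hidden_size, mid_channels=256):
        super().__init__()
        self.cnn1 = nn.Conv1d(hidden_size, mid_channels, kernel_size=3, padding=1)
        self.cnn2 = nn.Conv1d(mid_channels, 1, kernel_size=3, padding=1)

    def forward(self, last_hidden_state, attention_mask):
        mask = attention_mask.unsqueeze(1).to(last_hidden_state.dtype)  # [batch_size, 1, seq_len]
        x = last_hidden_state.permute(0, 2, 1) * mask   # zero padded tokens: [batch_size, hidden_size, seq_len]
        x = F.relu(self.cnn1(x)) * mask                  # zero again so cnn2 never sees padded positions
        x = self.cnn2(x)                                 # [batch_size, 1, seq_len]
        x = x.masked_fill(mask == 0, torch.finfo(x.dtype).min)  # padded positions cannot be the max
        return x.max(dim=2).values                       # [batch_size, 1]


torch.manual_seed(0)
head = Conv1DPooling(hidden_size=16, mid_channels=8)

# Padding invariance: the same 3-token sequence with and without 2 padded positions
seq = torch.randn(1, 3, 16)
padded = torch.cat([seq, torch.randn(1, 2, 16)], dim=1)
out_plain = head(seq, torch.ones(1, 3, dtype=torch.long))
out_padded = head(padded, torch.tensor([[1, 1, 1, 0, 0]]))
print(out_plain.shape)                                   # torch.Size([1, 1])
print(torch.allclose(out_plain, out_padded, atol=1e-6))  # True

# With the original kernel_size=2, padding=1, each conv adds one position
print(nn.Conv1d(16, 8, kernel_size=2, padding=1)(torch.randn(1, 16, 5)).shape)  # torch.Size([1, 8, 6])
```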

Hidden States Output

Motivation

  • The output of the last layer may not always be the best representation of the input text when fine-tuning for downstream tasks.
  • For pre-trained language models, including Transformers, the most transferable contextualized representations of the input text tend to occur in the middle layers, while the top layers specialize in language modeling. Using only the last layer's output may therefore limit the power of the pre-trained representations.

  • Layerwise CLS Embeddings: e.g. use the [CLS] embedding of the second-to-last layer
all_hidden_states = torch.stack(outputs[2])	# [n_layers, batch_size, seq_len, hidden_size]
cls_embeddings = all_hidden_states[-2, :, 0] # -2 = second-to-last layer (a base model has 13 entries: index 0 is the embedding output, then 12 blocks)
logits = nn.Linear(config.hidden_size, 1)(cls_embeddings) # regression head
  • Concatenate Pooling: concatenate the [CLS] embeddings from several layers into one vector, e.g. from the last 4 layers
all_hidden_states = torch.stack(outputs[2])
concatenate_pooling = torch.cat(
    (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]),-1
)
concatenate_pooling = concatenate_pooling[:, 0]
logits = nn.Linear(config.hidden_size*4, 1)(concatenate_pooling) # regression head
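The snippet above hard-codes four layers. A small helper generalises it to the last k layers; `concat_last_k_cls` is a hypothetical name, and taking the [CLS] slice before concatenating (rather than after) is our choice, which avoids building the full [batch_size, seq_len, 4 * hidden_size] tensor while giving the same result.

```python
import torch


def concat_last_k_cls(hidden_states, k=4):
    """Concatenate the [CLS] vectors of the last k layers (last layer first).

    hidden_states: tuple/list of [batch_size, seq_len, hidden_size] tensors,
    e.g. outputs.hidden_states from a HuggingFace model.
    """
    cls_per_layer = [hidden_states[-i][:, 0] for i in range(1, k + 1)]  # k x [batch_size, hidden_size]
    return torch.cat(cls_per_layer, dim=-1)                              # [batch_size, k * hidden_size]


torch.manual_seed(0)
hs = tuple(torch.randn(2, 5, 8) for _ in range(13))   # fake outputs: embeddings + 12 layers
out = concat_last_k_cls(hs, k=4)
print(out.shape)                                       # torch.Size([2, 32])
```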
  • Weighted Layer Pooling: each token embedding is a learned weighted mean of its representations from different hidden layers, and the weighted [CLS] embedding is used as the final representation. Among the techniques here, Weighted Layer Pooling often works best, although the best choice still depends on the task.
class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights = None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        self.layer_weights = layer_weights if layer_weights is not None \
            else nn.Parameter(
                torch.tensor([1] * (num_hidden_layers+1 - layer_start), dtype=torch.float)
            )

    def forward(self, all_hidden_states):
        all_layer_embedding = all_hidden_states[self.layer_start:, :, :, :]
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())	# [n_layer, batch_size, seq_len, embed_dim]
        weighted_average = (weight_factor*all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
        return weighted_average		# [batch_size, seq_len, embed_dim]
    
layer_start = 9		# 9th ~ 12th hidden layer
pooler = WeightedLayerPooling(
    config.num_hidden_layers, 
    layer_start=layer_start, layer_weights=None
)
weighted_pooling_embeddings = pooler(all_hidden_states)	# [batch_size, seq_len, embed_dim]
weighted_pooling_embeddings = weighted_pooling_embeddings[:, 0]	# Get CLS Embed. [batch_size, embed_dim]
logits = nn.Linear(config.hidden_size, 1)(weighted_pooling_embeddings)
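One caveat with the implementation above: dividing by `layer_weights.sum()` assumes the learned weights stay positive, and nothing in training enforces that. A common alternative, sketched here as our own variant rather than the original notebook's code, keeps one learnable logit per layer and normalises with a softmax, so the weights are always positive and sum to 1. With all-zero logits at initialisation it reduces to a plain average of the selected layers, which the check confirms.

```python
import torch
import torch.nn as nn


class SoftmaxWeightedLayerPooling(nn.Module):
    """Learned convex combination of hidden layers layer_start..num_hidden_layers."""

    def __init__(self, num_hidden_layers, layer_start=4):
        super().__init__()
        self.layer_start = layer_start
        # one logit per selected layer; zeros -> uniform weights at initialisation
        self.layer_logits = nn.Parameter(torch.zeros(num_hidden_layers + 1 - layer_start))

    def forward(self, all_hidden_states):
        # all_hidden_states: [n_layers, batch_size, seq_len, hidden_size], embeddings at index 0
        selected = all_hidden_states[self.layer_start:]
        weights = torch.softmax(self.layer_logits, dim=0)   # positive, sums to 1
        # weighted sum over the layer axis -> [batch_size, seq_len, hidden_size]
        return torch.einsum("l,lbsh->bsh", weights, selected)


torch.manual_seed(0)
all_hidden_states = torch.randn(13, 2, 6, 8)      # 12-layer model: embeddings + 12 layers
pooler = SoftmaxWeightedLayerPooling(num_hidden_layers=12, layer_start=9)
pooled = pooler(all_hidden_states)
print(pooled.shape)                                # torch.Size([2, 6, 8])
print(torch.allclose(pooled, all_hidden_states[9:].mean(dim=0), atol=1e-6))  # True
cls_embedding = pooled[:, 0]                       # [batch_size, hidden_size]
```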
  • LSTM / GRU Pooling: use an LSTM network to connect the intermediate [CLS] representations of all layers; the output of the last LSTM cell is used as the final representation.
    $o = h_{LSTM}^{L} = \mathrm{LSTM}\left(h_{CLS}^{i}\right), \quad i \in [1, L]$
class LSTMPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_lstm):
        super(LSTMPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_lstm = hiddendim_lstm
        self.lstm = nn.LSTM(self.hidden_size, self.hiddendim_lstm, batch_first=True)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, all_hidden_states):
        # [CLS] vector of every encoder layer, stacked as a sequence (index 0 = embedding output is skipped)
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0]
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=1)		# [batch_size, num_hidden_layers, hidden_size]
        out, _ = self.lstm(hidden_states, None)
        out = self.dropout(out[:, -1, :])
        return out

hiddendim_lstm = 256
pooler = LSTMPooling(config.num_hidden_layers, config.hidden_size, hiddendim_lstm)
lstm_pooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(hiddendim_lstm, 1)(lstm_pooling_embeddings) # regression head
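The bullet mentions GRU as an alternative, but only the LSTM version is shown. Below is a minimal GRU sketch with a class name (`GRUPooling`) of our own choosing: the per-layer [CLS] vectors are stacked along dim=1, so layers 1..L form the time axis of a batch-first GRU, and the last time step is returned. The usage lines also run a batch of one to show the result does not depend on the batch size.

```python
import torch
import torch.nn as nn


class GRUPooling(nn.Module):
    """Run a GRU over the [CLS] vectors of layers 1..L and keep the last step."""

    def __init__(self, num_layers, hidden_size, hiddendim_gru):
        super().__init__()
        self.num_hidden_layers = num_layers
        self.gru = nn.GRU(hidden_size, hiddendim_gru, batch_first=True)
        self.dropout = nn.Dropout(0.1)

    def forward(self, all_hidden_states):
        # [batch_size, num_layers, hidden_size]; index 0 (embedding output) is skipped
        cls_seq = torch.stack([all_hidden_states[i][:, 0]
                               for i in range(1, self.num_hidden_layers + 1)], dim=1)
        out, _ = self.gru(cls_seq)            # [batch_size, num_layers, hiddendim_gru]
        return self.dropout(out[:, -1, :])    # [batch_size, hiddendim_gru]


torch.manual_seed(0)
all_hidden_states = torch.randn(13, 2, 6, 32)     # embeddings + 12 layers
pooler = GRUPooling(num_layers=12, hidden_size=32, hiddendim_gru=16).eval()
features = pooler(all_hidden_states)
print(features.shape)                             # torch.Size([2, 16])
print(pooler(torch.randn(13, 1, 6, 32)).shape)    # torch.Size([1, 16])
```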
  • Attention Pooling: a dot-product attention module can dynamically combine the intermediate [CLS] representations of all layers:
    $o = \operatorname{softmax}\left(q\, h_{CLS}^{T}\right) h_{CLS} W_h$, where $h_{CLS} \in \mathbb{R}^{n_l \times d_h}$ stacks the [CLS] embeddings of $n_l$ layers, and $q \in \mathbb{R}^{1 \times d_h}$ and $W_h \in \mathbb{R}^{d_h \times d_{fc}}$ are learned weights.
class AttentionPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_fc):
        super(AttentionPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_fc = hiddendim_fc
        self.dropout = nn.Dropout(0.1)

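        # Initialise q [1, hidden_size] and W_h [hidden_size, hiddendim_fc] from N(0, 0.1^2).
        # .float() must be applied before wrapping in nn.Parameter; otherwise the result is a
        # plain tensor that is not registered as a parameter and is therefore never trained.
        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.hidden_size))
        self.q = nn.Parameter(torch.from_numpy(q_t).float())
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.hidden_size, self.hiddendim_fc))
        self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())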

    def forward(self, all_hidden_states):
        # [CLS] vector of every encoder layer (index 0 = embedding output is skipped)
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0]
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=1)	# [batch_size, num_hidden_layers, embed_dim]
        out = self.attention(hidden_states)
        out = self.dropout(out)
        return out

    def attention(self, h):
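        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)          # q h^T: [batch_size, num_layers]
        v = F.softmax(v, -1)                                              # one attention weight per layer
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)        # weighted sum of layers: [batch_size, embed_dim, 1]
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)     # project with W_h: [batch_size, hiddendim_fc]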
        return v

hiddendim_fc = 128
pooler = AttentionPooling(config.num_hidden_layers, config.hidden_size, hiddendim_fc)
attention_pooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(hiddendim_fc, 1)(attention_pooling_embeddings) # regression head
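The attention formula can also be written almost term by term with `torch.einsum`. In this sketch the function name `attention_pool` is hypothetical, and `q` and `w_h` are passed in explicitly instead of being module parameters; returning the per-layer weights as well makes it easy to inspect which layers the pooled vector draws on.

```python
import torch


def attention_pool(h_cls, q, w_h):
    """o = softmax(q h_CLS^T) h_CLS W_h, computed per example.

    h_cls: [batch_size, n_l, d_h]  [CLS] embeddings from n_l layers
    q:     [1, d_h]                learned query
    w_h:   [d_h, d_fc]             learned projection
    """
    scores = torch.einsum("od,bld->bl", q, h_cls)      # q h_CLS^T -> [batch_size, n_l]
    attn = torch.softmax(scores, dim=-1)               # one weight per layer
    pooled = torch.einsum("bl,bld->bd", attn, h_cls)   # weighted sum of layers -> [batch_size, d_h]
    return pooled @ w_h, attn                          # [batch_size, d_fc], [batch_size, n_l]


torch.manual_seed(0)
h_cls = torch.randn(2, 12, 32)
q = torch.randn(1, 32) * 0.1
w_h = torch.randn(32, 16) * 0.1
o, attn = attention_pool(h_cls, q, w_h)
print(o.shape, attn.shape)                        # torch.Size([2, 16]) torch.Size([2, 12])
print(torch.allclose(attn.sum(dim=-1), torch.ones(2)))   # True
```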
  • WKPooling: following SBERT-WK, the representations of each token from the selected layers are first fused into a single vector, weighting every layer by alignment and novelty scores obtained from a QR decomposition over a window of neighbouring layers; the fused token vectors are then combined into a sentence embedding, weighting every token by its importance (the variance of the cosine similarities between its representations in adjacent layers).
class WKPooling(nn.Module):
    def __init__(self, layer_start: int = 4, context_window_size: int = 2):
        super(WKPooling, self).__init__()
        self.layer_start = layer_start
        self.context_window_size = context_window_size

    def forward(self, all_hidden_states, attention_mask):
        # all_hidden_states: stacked hidden states [n_layers, batch_size, seq_len, hidden_size]
        # attention_mask: [batch_size, seq_len], 1 for real tokens and 0 for padding
        org_device = all_hidden_states.device
        all_layer_embedding = all_hidden_states.transpose(1, 0)  # [batch_size, n_layers, seq_len, hidden_size]
        all_layer_embedding = all_layer_embedding[:, self.layer_start:, :, :]  # Keep layers from layer_start onward

        # QR decomposition is slow on GPU (see https://github.com/pytorch/pytorch/issues/22573), so compute it on CPU
        all_layer_embedding = all_layer_embedding.cpu()

        attention_mask = attention_mask.cpu().numpy()
        unmask_num = np.array([sum(mask) for mask in attention_mask]) - 1  # Not considering the last item (the final [SEP] token)
        embedding = []

        # One sentence at a time
        for sent_index in range(len(unmask_num)):
            sentence_feature = all_layer_embedding[sent_index, :, :unmask_num[sent_index], :]
            one_sentence_embedding = []
            # Process each token
            for token_index in range(sentence_feature.shape[1]):
                token_feature = sentence_feature[:, token_index, :]
                # 'Unified Word Representation'
                token_embedding = self.unify_token(token_feature)
                one_sentence_embedding.append(token_embedding)

            one_sentence_embedding = torch.stack(one_sentence_embedding)
            sentence_embedding = self.unify_sentence(sentence_feature, one_sentence_embedding)
            embedding.append(sentence_embedding)

        output_vector = torch.stack(embedding).to(org_device)
        return output_vector  # [batch_size, hidden_size]

    def unify_token(self, token_feature):
        ## Unify Token Representation (token_feature: [n_selected_layers, hidden_size])
        window_size = self.context_window_size

        alpha_alignment = torch.zeros(token_feature.size()[0], device=token_feature.device)
        alpha_novelty = torch.zeros(token_feature.size()[0], device=token_feature.device)

        for k in range(token_feature.size()[0]):
            # max(0, ...) keeps a negative start index from wrapping around to the end of the tensor
            left_window = token_feature[max(0, k - window_size):k, :]
            right_window = token_feature[k + 1:k + window_size + 1, :]
            window_matrix = torch.cat([left_window, right_window, token_feature[k, :][None, :]])
            Q, R = torch.linalg.qr(window_matrix.T)  # torch.qr is deprecated in favour of torch.linalg.qr

            r = R[:, -1]
            alpha_alignment[k] = torch.mean(self.norm_vector(R[:-1, :-1], dim=0), dim=1).matmul(R[:-1, -1]) / torch.norm(r[:-1])
            alpha_alignment[k] = 1 / (alpha_alignment[k] * window_matrix.size()[0] * 2)
            alpha_novelty[k] = torch.abs(r[-1]) / torch.norm(r)

        # Sum Norm
        alpha_alignment = alpha_alignment / torch.sum(alpha_alignment)  # Normalization Choice
        alpha_novelty = alpha_novelty / torch.sum(alpha_novelty)

        alpha = alpha_novelty + alpha_alignment
        alpha = alpha / torch.sum(alpha)  # Normalize

        out_embedding = torch.mv(token_feature.t(), alpha)
        return out_embedding

    def norm_vector(self, vec, p=2, dim=0):
        ## Implements the normalize() function from sklearn
        vec_norm = torch.norm(vec, p=p, dim=dim)
        return vec.div(vec_norm.expand_as(vec))

    def unify_sentence(self, sentence_feature, one_sentence_embedding):
        ## Unify Sentence By Token Importance
        sent_len = one_sentence_embedding.size()[0]

        var_token = torch.zeros(sent_len, device=one_sentence_embedding.device)
        for token_index in range(sent_len):
            token_feature = sentence_feature[:, token_index, :]
            sim_map = self.cosine_similarity_torch(token_feature)
            var_token[token_index] = torch.var(sim_map.diagonal(-1))  # similarities between adjacent layers

        var_token = var_token / torch.sum(var_token)
        sentence_embedding = torch.mv(one_sentence_embedding.t(), var_token)

        return sentence_embedding

    def cosine_similarity_torch(self, x1, x2=None, eps=1e-8):
        x2 = x1 if x2 is None else x2
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
        return torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)

pooler = WKPooling(layer_start=9)
wkpooling_embeddings = pooler(all_hidden_states, attention_mask)  # [batch_size, hidden_size]
logits = nn.Linear(config.hidden_size, 1)(wkpooling_embeddings) # regression head

More…

  • The SWA, Apex AMP & Interpreting Transformers in Torch notebook implements Stochastic Weight Averaging with NVIDIA Apex for Transformers in PyTorch. It also shows how to interpret Transformers interactively with LIT (Language Interpretability Tool), a platform for understanding NLP models.
  • The On Stability of Few-Sample Transformer Fine-Tuning notebook goes over various remedies for improving the stability of few-sample fine-tuning, which show a significant performance improvement over simple fine-tuning.
  • The Speeding up Transformer w/ Optimization Strategies notebook explains 5 optimization strategies in depth, with code. These techniques can improve model performance in terms of both speed and accuracy.
  • Other strategies: Dense Pooling, Word Weight (TF-IDF) Pooling, Async Pooling, Parallel / Hierarchical Aggregation


11-21 10:22