
Self-Guided Contrastive Learning for BERT Sentence Representations

ACL-IJCNLP 2021
Taeuk Kim, Kang Min Yoo, Sang-goo Lee

One-Line Summary

A self-supervised contrastive learning method that leverages BERT's own intermediate-layer representations as guidance signals to produce high-quality sentence embeddings -- without any external data, augmentation, or labeled pairs -- achieving strong results on STS benchmarks (e.g., 74.62 avg. on all STS tasks with BERT-base) and diverse transfer tasks including multilingual STS.

Paper overview
Figure 1. BERT(-base)'s layer-wise performance with different pooling methods on the STS-B test set. The performance can vary dramatically depending on the selected layer and pooling strategy. The self-guided training methods (SG / SG-OPT) achieve much improved results compared to all baselines.

Background & Motivation

Pre-trained language models like BERT produce powerful contextual word representations, but deriving high-quality sentence-level representations from them remains a surprisingly difficult problem. Naive approaches such as averaging token embeddings or using the [CLS] token often yield sentence embeddings that underperform even non-contextual baselines like GloVe on semantic similarity tasks. This phenomenon, sometimes called the anisotropy problem, occurs because BERT's token representations tend to occupy a narrow cone in the embedding space, making cosine similarity an unreliable measure of semantic similarity.

How bad is the problem? In a preliminary experiment, the authors constructed sentence embeddings using various combinations of BERT layers and pooling methods, and tested them on STS-B. BERT-base's Spearman correlation (x100) ranged from as low as 16.71 ([CLS] at layer 10) to 63.19 (max pooling at layer 2). This dramatic variance -- nearly 47 points depending on the layer/pooling choice -- revealed that the standard practice of using the final-layer [CLS] token is far from optimal, and that there is substantial room for improvement.
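For concreteness, the layer/pooling sweep can be reproduced along the following lines. This is a minimal sketch, not the authors' code; the model name, example sentences, and helper function are illustrative assumptions.

```python
# Minimal sketch of a layer/pooling sweep over BERT's hidden states (illustrative only).
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True).eval()

def sentence_embedding(sentence, layer=2, pooling="max"):
    """Pool one of BERT's hidden layers (0 = embeddings, 12 = final) into a sentence vector."""
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        hidden_states = model(**inputs).hidden_states   # tuple of [1, seq_len, 768] tensors
    h = hidden_states[layer].squeeze(0)
    mask = inputs["attention_mask"].squeeze(0).bool()
    if pooling == "cls":
        return h[0]
    if pooling == "mean":
        return h[mask].mean(dim=0)
    if pooling == "max":
        return h[mask].max(dim=0).values
    raise ValueError(f"unknown pooling: {pooling}")

a = sentence_embedding("A man is playing a guitar.", layer=2, pooling="max")
b = sentence_embedding("Someone is playing an instrument.", layer=2, pooling="max")
print(torch.cosine_similarity(a, b, dim=0).item())
```

Sweeping `layer` over 0..12 and `pooling` over {cls, mean, max} and scoring the resulting similarities against STS-B gold labels reproduces the kind of variance described above.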

Limitation of Prior Work: Existing contrastive learning methods for sentence embeddings (e.g., CERT by Fang and Xie (2020) using back-translation, or NLI-supervised methods like SBERT) rely on external datasets or augmentation pipelines to construct positive and negative pairs. Back-translation requires a translation model and introduces noise; NLI supervision limits applicability to domains where such data exists. These dependencies restrict their use in specialized text domains and introduce potential domain mismatch.

Key Insight: Different layers of BERT capture different levels of linguistic abstraction -- lower layers encode surface-level and syntactic features while higher layers encode more semantic information (Jawahar et al., 2019). This natural hierarchy provides a built-in source of diverse "views" of the same sentence, which can serve as a self-supervision signal for contrastive learning without any external data or augmentation. The intermediate representations are conceptually guaranteed to represent the corresponding sentence, making them ideal pivots for contrastive training.

Building on this insight, the authors propose Self-Guided Contrastive Learning, which fine-tunes BERT by contrasting its final-layer [CLS] representation against its own intermediate-layer representations. The method comes in two variants: SG (using the base NT-Xent loss) and SG-OPT (using a redesigned, optimized contrastive objective). The approach requires no external data, no augmentation, and no post-processing at inference time.

Proposed Method

The method fine-tunes BERT through a contrastive learning framework requiring only unlabeled sentences. The core idea is to clone BERT into two copies -- BERTF (fixed, providing training signals from intermediate layers) and BERTT (tuned, whose final-layer [CLS] is optimized as the sentence embedding) -- and train via a customized contrastive objective.

1. Dual-BERT Architecture with Self-Guided Pairs
BERT is cloned into BERTF (fixed) and BERTT (tuned) at the start of training. For each input sentence si, BERTF produces token-level hidden representations Hi,k at every layer k (0 to l). A pooling function p (max pooling by default) converts each layer's token representations into a sentence-level view hi,k = p(Hi,k). A sampling function sigma selects one or more of these views as the positive target hi. Meanwhile, BERTT produces the sentence embedding ci = BERTT(si)[CLS] from its final layer. The key design choice of keeping BERTF fixed prevents the training signal from degenerating as training proceeds -- if both copies were updated simultaneously, the intermediate-layer representations would drift and lose their value as diverse views.
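The dual-encoder setup can be sketched roughly as follows. The model name, the use of max pooling, and the freezing scheme follow the description above; everything else is an illustrative assumption rather than the official implementation.

```python
# Rough sketch of the dual-BERT setup: BERT_F is a frozen clone providing per-layer
# views h_{i,k} (max pooling), BERT_T is tuned and supplies the final-layer [CLS]
# embedding c_i.
import copy
import torch
from transformers import AutoModel

bert_t = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
bert_f = copy.deepcopy(bert_t)                 # BERT_F: fixed copy that supplies the views
for p in bert_f.parameters():
    p.requires_grad = False
for p in bert_t.embeddings.parameters():       # layer 0 of BERT_T is kept frozen
    p.requires_grad = False

def views_and_embedding(input_ids, attention_mask):
    """Return the per-layer views h_{i,k} from BERT_F and the [CLS] embeddings c_i from BERT_T."""
    pad = ~attention_mask.unsqueeze(-1).bool()
    with torch.no_grad():
        hidden = bert_f(input_ids, attention_mask=attention_mask).hidden_states
    # max pooling p(H_{i,k}) over non-padding tokens, for every layer k = 0..l
    views = torch.stack(
        [h.masked_fill(pad, float("-inf")).max(dim=1).values for h in hidden], dim=1
    )                                                                    # [b, l+1, dim]
    cls = bert_t(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]  # [b, dim]
    return views, cls
```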
2. Base Contrastive Loss (Lbase)
Given a mini-batch of b sentences, the set X = {ci} U {hi} is formed. The NT-Xent loss for each element xm is Lbase(m) = -log(phi(xm, mu(xm)) / Z), where Z sums phi(xm, xn) over all other elements xn in X, phi(u, v) = exp(g(f(u), f(v)) / tau), g is cosine similarity, f is a two-layer MLP projection head (hidden size 4096, GELU activation), tau is a temperature hyperparameter, and mu(.) is a matching function that maps each ci to its corresponding hi and vice versa. A regularizer Lreg = ||BERTF - BERTT||^2_2 (the squared L2 distance between the two copies' parameters) keeps BERTT from diverging too far from BERTF, and layer 0 (the embedding layer) of BERTT is frozen for stable learning.
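A hedged sketch of this base objective follows. The projection-head shape (two layers, hidden size 4096, GELU) and the ci <-> hi pairing are taken from the text; the rest is an illustrative assumption.

```python
# Sketch of the base NT-Xent loss over X = {c_i} U {h_i} (illustrative, not the official code).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    def __init__(self, dim=768, hidden=4096):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        return self.net(x)

def nt_xent(c, h, head, tau=0.01):
    """c, h: [b, dim]; the positive for element m is its counterpart (c_i <-> h_i)."""
    b = c.size(0)
    x = F.normalize(head(torch.cat([c, h], dim=0)), dim=-1)          # apply f(.), unit-normalize
    sim = x @ x.t() / tau                                            # cosine similarities / temperature
    sim.fill_diagonal_(float("-inf"))                                # exclude x_m itself from Z
    targets = (torch.arange(2 * b, device=c.device) + b) % (2 * b)   # mu(x_m): c_i <-> h_i
    return F.cross_entropy(sim, targets)
```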
3. Optimized Loss (Lopt) with Three Refinements
The authors analyze the four interaction factors in NT-Xent -- (1) ci toward hi, (2) ci away from cj, (3) ci away from hj, (4) hi away from hj -- and discover that only (1) and (3) are essential. Three modifications are applied: (a) The loss is re-centered on ci, treating hi only as targets to approach (not as anchors themselves), removing factor (4). (b) The repulsion between sentence embeddings (factor 2) is also removed as it proves insignificant. (c) Multiple intermediate-layer views {hi,k} for all layers k=0..l are used simultaneously instead of sampling a single one, yielding richer training signals. The final optimized loss Lopt sums over all b sentences and all l+1 layers, normalized by b(l+1).
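A simplified sketch of the optimized objective under these assumptions: ci is the only anchor, attracted to its own per-layer views hi,k (all layers) and repelled from other sentences' views, with no repulsion between sentence embeddings and no h-anchored terms. The exact negative-set construction in the paper may differ from this version.

```python
# Simplified sketch of the optimized loss (factors (1) and (3) only, all layers used).
import torch
import torch.nn.functional as F

def sg_opt_loss(c, views, head, tau=0.01):
    """c: [b, dim] sentence embeddings; views: [b, L, dim] pooled views h_{i,k}, k = 0..l."""
    b, L, dim = views.shape
    zc = F.normalize(head(c), dim=-1)                              # [b, dim]
    zh = F.normalize(head(views.reshape(b * L, dim)), dim=-1)      # [b*L, dim]
    sim = (zc @ zh.t() / tau).view(b, b, L)                        # sim[i, j, k] = g(f(c_i), f(h_{j,k})) / tau
    same = torch.eye(b, dtype=torch.bool, device=c.device)
    loss = 0.0
    for k in range(L):
        pos = sim[:, :, k][same].unsqueeze(1)                      # attract: c_i -> h_{i,k}
        neg = sim[~same].view(b, b - 1, L).reshape(b, -1)          # repel: views of other sentences
        logits = torch.cat([pos, neg], dim=1)                      # positive sits in column 0
        loss = loss + F.cross_entropy(logits, torch.zeros(b, dtype=torch.long, device=c.device))
    return loss / L                                                # average over the l+1 layers
```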

Training Details: The method uses plain sentences from STS-B (training, validation, and test sets, without gold annotations) to fine-tune BERT -- identical to the setup used by BERT-flow. Hyperparameters are tuned on the STS-B validation set: temperature tau = 0.01 and regularization weight lambda = 0.1. The projection head f is a two-layer MLP with hidden size 4096 and GELU activations. Results for all trainable models are reported as the average of 8 separate runs to reduce the effect of randomness. The implementation is based on HuggingFace Transformers and SBERT, and is publicly available at github.com/galsang/SG-BERT.
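Putting the sketches above together, one training step might look like the following. It reuses the names defined earlier (views_and_embedding, sg_opt_loss, ProjectionHead, bert_t, bert_f); lambda = 0.1 and tau = 0.01 follow the text, while the optimizer choice and learning rate are assumptions.

```python
# One possible training step combining the earlier sketches (illustrative only).
import torch

head = ProjectionHead()
lam = 0.1
optimizer = torch.optim.AdamW(
    [p for p in bert_t.parameters() if p.requires_grad] + list(head.parameters()), lr=2e-5
)

def train_step(input_ids, attention_mask):
    views, cls = views_and_embedding(input_ids, attention_mask)
    contrastive = sg_opt_loss(cls, views, head, tau=0.01)
    # L_reg = ||BERT_F - BERT_T||_2^2 over parameters, keeping the tuned copy close to the original
    reg = sum((pf - pt).pow(2).sum() for pf, pt in zip(bert_f.parameters(), bert_t.parameters()))
    loss = contrastive + lam * reg
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```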

Experimental Results

The method (SG / SG-OPT) is evaluated comprehensively on Semantic Textual Similarity (STS) tasks, SentEval transfer tasks, and multilingual STS, and analyzed for computational efficiency and domain robustness. It is tested across four base models: BERT-base, BERT-large, SBERT-base, and SBERT-large.

Semantic Textual Similarity (Spearman Correlation x100, BERT-base)

Model | Pooling | STS-B | SICK-R | STS12 | STS13 | STS14 | STS15 | STS16 | Avg.
GloVe | Mean | 58.02 | 53.76 | 55.14 | 70.66 | 59.73 | 68.25 | 63.66 | 61.32
No tuning | CLS | 20.30 | 42.42 | 21.54 | 32.11 | 21.28 | 37.89 | 44.24 | 31.40
No tuning | Mean | 47.29 | 58.22 | 30.87 | 59.89 | 47.73 | 60.29 | 63.73 | 52.57
Flow | Mean-2 | 71.35 | 64.95 | 64.32 | 69.72 | 63.67 | 77.77 | 69.59 | 68.77
Contrastive (BT) | CLS | 63.27 | 66.91 | 54.26 | 64.03 | 54.28 | 68.19 | 67.50 | 62.63
Contrastive (SG) | CLS | 75.08 | 68.19 | 63.60 | 76.48 | 67.57 | 79.42 | 74.85 | 72.17
Contrastive (SG-OPT) | CLS | 77.23 | 70.20 | 66.84 | 80.13 | 71.23 | 81.56 | 77.17 | 74.62

SentEval Transfer Tasks (Accuracy, BERT-base)

Model | MR | CR | SUBJ | MPQA | SST2 | TREC | MRPC | Avg.
+ Mean pooling | 81.46 | 86.71 | 95.37 | 87.90 | 85.83 | 90.30 | 73.36 | 85.85
+ WK pooling | 80.64 | 85.53 | 95.27 | 88.63 | 85.03 | 94.03 | 71.71 | 85.83
+ SG-OPT | 82.47 | 87.42 | 95.40 | 88.92 | 86.20 | 91.60 | 74.21 | 86.60

Ablation Study

The ablation study systematically validates each design decision by evaluating variants on all STS test sets:

Variant | STS Avg. | Delta
SG-OPT (Lopt3) | 74.62 | --
+ Lopt2 (single view) | 73.14 | -1.48
+ Lopt1 (with cj repulsion) | 72.61 | -2.01
+ SG (Lbase, no optimization) | 72.17 | -2.45
tau = 0.1 (vs. 0.01) | 70.39 | -4.23
lambda = 0.0 (no regularizer) | 73.76 | -0.86
lambda = 1.0 (strong regularizer) | 73.18 | -1.44
Without projection head (f) | 72.78 | -1.84

Computational Efficiency

Method | Training Time | Inference Time (sec)
BERT + Mean pooling | -- | 13.94
BERT + WK pooling | -- | 197.03 (~3.3 min)
BERT + Flow | 155.37 s (~2.6 min) | 28.49
BERT + SG-OPT | 455.02 s (~7.5 min) | 10.51

Although SG-OPT requires moderate training time (~7.5 minutes), it achieves the fastest inference of all methods (10.51s on STS-B) since no post-processing is needed once training is completed. WK pooling is the slowest at inference (197s), and Flow also incurs overhead (28.5s) due to its flow-based transformation.

Key Experimental Findings

On STS with BERT-base, SG-OPT (74.62 avg.) outperforms the strongest post-processing baseline (Flow, 68.77) and the back-translation contrastive baseline (62.63) by wide margins, and SG alone already reaches 72.17. On SentEval, SG-OPT raises the transfer-task average to 86.60, above mean pooling (85.85) and WK pooling (85.83). The ablation confirms that each refinement of the objective, the low temperature (tau = 0.01), the regularizer, and the projection head all contribute, and the efficiency comparison shows SG-OPT has the fastest inference among the compared methods.

Why It Matters

This work showed that BERT can supervise its own sentence-embedding training: by contrasting the final-layer [CLS] representation against the model's intermediate-layer views, high-quality sentence embeddings are obtained without external data, augmentation, labeled pairs, or inference-time post-processing. The resulting embeddings outperform post-processing approaches such as BERT-flow on STS, transfer well to SentEval tasks and multilingual STS, and are cheap to use at inference time, making the method practical for specialized domains where NLI or back-translation resources are unavailable.
