EN KO
← 전체 논문 목록

Self-Guided Contrastive Learning for BERT Sentence Representations

ACL-IJCNLP 2021
Taeuk Kim, Kang Min Yoo, Sang-goo Lee

한줄 요약

BERT의 중간 레이어 표현을 자기 지도 신호로 활용하는 대조 학습 방법으로, 외부 데이터나 증강, 레이블 없이도 고품질 문장 임베딩을 생성하며 STS 벤치마크(예: BERT-base 기준 전체 STS 평균 74.62)와 다국어 STS, 다양한 전이 학습 태스크에서 우수한 성능을 달성합니다.

논문 개요
Figure 1. BERT(-base)의 레이어별 성능을 다양한 풀링 방법으로 STS-B 테스트 세트에서 측정한 결과. 선택된 레이어와 풀링 전략에 따라 성능이 극적으로 달라지며, 자기 안내 학습 방법(SG / SG-OPT)은 모든 기준선 대비 크게 향상된 결과를 달성합니다.

배경 및 동기

BERT와 같은 사전학습 언어 모델은 강력한 문맥적 단어 표현을 생성하지만, 이로부터 고품질의 문장 수준 표현을 도출하는 것은 여전히 어려운 과제입니다. 토큰 임베딩 평균이나 [CLS] 토큰을 사용하는 단순한 방법은 의미 유사도 태스크에서 GloVe와 같은 비문맥적 기준선보다도 낮은 성능을 보이는 경우가 많습니다. 이러한 현상은 비등방성(anisotropy) 문제라 불리며, BERT의 토큰 표현이 임베딩 공간에서 좁은 원뿔 형태로 분포하여 코사인 유사도가 의미적 유사성을 제대로 반영하지 못하기 때문에 발생합니다.

문제의 심각성: 사전 실험에서 저자들은 다양한 BERT 레이어와 풀링 방법의 조합으로 문장 임베딩을 구성하여 STS-B에서 테스트했습니다. BERT-base의 Spearman 상관계수(x100)는 16.71(10번째 레이어의 [CLS])부터 63.19(2번째 레이어의 max pooling)까지 범위를 보였습니다. 레이어/풀링 선택에 따라 거의 47점이나 차이가 나는 이 극적인 편차는, 최종 레이어 [CLS] 토큰을 사용하는 표준 관행이 최적과 거리가 멀며 개선의 여지가 크다는 것을 보여주었습니다.

기존 연구의 한계: 문장 임베딩을 위한 기존 대조 학습 방법들(예: Fang and Xie (2020)의 역번역 기반 CERT, NLI 지도 학습 기반 SBERT)은 긍정/부정 쌍 구성을 위해 외부 데이터셋이나 증강 파이프라인에 의존합니다. 역번역은 번역 모델이 필요하고 노이즈를 유발하며, NLI 지도 학습은 해당 데이터가 존재하는 도메인으로 적용을 제한합니다. 이러한 의존성은 특수 도메인 텍스트에 적용할 때 도메인 불일치 문제를 야기합니다.

핵심 관찰: BERT의 서로 다른 레이어는 서로 다른 수준의 언어적 추상화를 포착합니다 -- 하위 레이어는 표면적/구문적 특징을, 상위 레이어는 의미적 정보를 인코딩합니다(Jawahar et al., 2019). 이러한 자연스러운 계층 구조가 동일 문장에 대한 다양한 "관점(view)"을 제공하며, 중간 레이어 표현은 해당 문장을 개념적으로 대표하는 것이 보장되므로 대조 학습의 이상적인 피벗으로 활용할 수 있습니다.

이 관찰을 바탕으로, 저자들은 BERT의 최종 레이어 [CLS] 표현을 중간 레이어 표현과 대조하여 미세조정하는 Self-Guided Contrastive Learning을 제안합니다. 기본 NT-Xent 손실을 사용하는 SG와 최적화된 대조 목적함수를 사용하는 SG-OPT 두 가지 변형이 있으며, 외부 데이터, 증강, 추론 시 후처리가 일절 필요 없습니다.

제안 방법

본 방법은 비레이블 문장만을 사용하는 대조 학습 프레임워크를 통해 BERT를 미세조정합니다. 핵심 아이디어는 BERT를 두 복사본 -- BERTF(고정, 중간 레이어에서 학습 신호 제공)와 BERTT(조정, 최종 레이어 [CLS]가 문장 임베딩으로 최적화) -- 으로 복제하고, 맞춤형 대조 목적함수로 학습하는 것입니다.

1
이중 BERT 아키텍처와 자기 안내 쌍 구성
학습 시작 시 BERT를 BERTF(고정)와 BERTT(조정)로 복제합니다. 각 입력 문장 si에 대해, BERTF는 모든 레이어 k(0~l)에서 토큰 수준 은닉 표현 Hi,k를 생성합니다. 풀링 함수 p(기본적으로 max pooling)가 각 레이어의 토큰 표현을 문장 수준 관점 hi,k = p(Hi,k)로 변환합니다. 샘플링 함수 sigma가 이 중 하나 이상을 긍정 타겟 hi로 선택합니다. 한편 BERTT는 최종 레이어에서 문장 임베딩 ci = BERTT(si)[CLS]를 생성합니다. BERTF를 고정하는 핵심 설계 결정은 학습이 진행됨에 따라 학습 신호가 퇴화되는 것을 방지합니다 -- 두 복사본이 동시에 업데이트되면 중간 레이어 표현이 표류하여 다양한 관점으로서의 가치를 상실합니다.
2
기본 대조 손실 (Lbase)
b개 문장의 미니배치에서, 집합 X = {ci} U {hi}를 구성합니다. NT-Xent 손실을 모든 요소에 대해 계산합니다: Lbasem = -log(phi(xm, mu(xm)) / Z). 여기서 phi(u,v) = exp(g(f(u), f(v)) / tau)이고, g는 코사인 유사도, f는 2층 MLP 투영 헤드(은닉 크기 4096, GELU 활성화), tau는 온도 하이퍼파라미터, mu(.)는 각 ci를 대응하는 hi에 매핑하는 매칭 함수입니다. 정규화항 Lreg = ||BERTF - BERTT||22가 BERTT가 BERTF에서 너무 멀어지는 것을 방지합니다. 안정적 학습을 위해 BERTT의 0번째 레이어도 고정합니다.
3
최적화 손실 (Lopt) -- 세 가지 개선
저자들은 NT-Xent의 네 가지 상호작용 요소 -- (1) ci와 hi의 인력, (2) ci와 cj의 척력, (3) ci와 hj의 척력, (4) hi와 hj의 척력 -- 를 분석하여, (1)과 (3)만이 필수적임을 발견합니다. 세 가지 수정이 적용됩니다: (a) 손실을 ci 중심으로 재설계하여, hi는 접근할 타겟으로만 취급하고 앵커 자체로는 사용하지 않아 요소 (4)를 제거합니다. (b) 문장 임베딩 간 척력(요소 2)도 성능에 무의미하므로 제거합니다. (c) 단일 관점을 샘플링하는 대신, 모든 레이어 k=0..l의 다중 중간 레이어 관점 {hi,k}을 동시에 활용하여 더 풍부한 학습 신호를 제공합니다. 최종 최적화 손실 Lopt는 b개 문장과 l+1개 레이어에 대해 합산하고 b(l+1)로 정규화합니다.

학습 상세 설정: STS-B의 일반 문장(학습, 검증, 테스트 세트의 금표준 주석 미사용)을 사용하여 BERT를 미세조정합니다 -- BERT-flow와 동일한 설정입니다. 하이퍼파라미터는 STS-B 검증 세트에서 최적화: 온도 tau=0.01, 정규화 가중치 lambda=0.1. 투영 헤드 f는 은닉 크기 4096, GELU 활성화의 2층 MLP입니다. 모든 학습 모델 성능은 무작위성을 줄이기 위해 8회 실행의 평균으로 보고됩니다. 구현은 HuggingFace Transformers와 SBERT 라이브러리를 기반으로 하며, github.com/galsang/SG-BERT에서 공개되어 있습니다.

실험 결과

SG-CL은 의미 텍스트 유사도(STS) 태스크, SentEval 전이 학습 태스크, 다국어 STS에서 포괄적으로 평가되며, 계산 효율성도메인 강건성도 분석됩니다. BERT-base, BERT-large, SBERT-base, SBERT-large 네 가지 기본 모델에서 테스트되었습니다.

의미 텍스트 유사도 (Spearman 상관계수 x100, BERT-base)

모델풀링STS-BSICK-RSTS12STS13STS14STS15STS16평균
GloVeMean58.0253.7655.1470.6659.7368.2563.6661.32
미조정CLS20.3042.4221.5432.1121.2837.8944.2431.40
미조정Mean47.2958.2230.8759.8947.7360.2963.7352.57
FlowMean-271.3564.9564.3269.7263.6777.7769.5968.77
Contrastive (BT)CLS63.2766.9154.2664.0354.2868.1967.5062.63
Contrastive (SG)CLS75.0868.1963.6076.4867.5779.4274.8572.17
Contrastive (SG-OPT)CLS77.2370.2066.8480.1371.2381.5677.1774.62

SentEval 전이 학습 태스크 (정확도, BERT-base)

모델MRCRSUBJMPQASST2TRECMRPC평균
+ Mean pooling81.4686.7195.3787.9085.8390.3073.3685.85
+ WK pooling80.6485.5395.2788.6385.0394.0371.7185.83
+ SG-OPT82.4787.4295.4088.9286.2091.6074.2186.60

어블레이션 연구

어블레이션 연구는 전체 STS 테스트 세트에서 각 변형을 평가하여 모든 설계 결정을 체계적으로 검증합니다:

변형STS 평균변화
SG-OPT (Lopt3)74.62--
+ Lopt2 (단일 관점)73.14-1.48
+ Lopt1 (cj 척력 포함)72.61-2.01
+ SG (Lbase, 최적화 없음)72.17-2.45
tau = 0.1 (0.01 대비)70.39-4.23
lambda = 0.0 (정규화 없음)73.76-0.86
lambda = 1.0 (강한 정규화)73.18-1.44
투영 헤드 (f) 제거72.78-1.84

계산 효율성

방법학습 시간추론 시간 (초)
BERT + Mean pooling--13.94
BERT + WK pooling--197.03 (~3.3분)
BERT + Flow155.37초 (~2.6분)28.49
BERT + SG-OPT455.02초 (~7.5분)10.51

SG-OPT는 적절한 학습 시간(~7.5분)이 필요하지만, 학습 완료 후 후처리가 필요 없기 때문에 모든 방법 중 가장 빠른 추론(STS-B에서 10.51초)을 달성합니다. WK pooling은 추론 시 가장 느리고(197초), Flow도 흐름 기반 변환으로 인해 오버헤드가 발생합니다(28.5초).

주요 실험 결과

의의

본 연구는 문장 표현 학습 분야에 여러 중요한 기여를 했습니다:

링크

Representation Learning