카테고리 없음

BAM! Just Like That: Simple and Efficient Parameter Upcycling for Mixture of Experts

서히! 2025. 8. 10. 23:08

https://arxiv.org/abs/2408.08274

 

BAM! Just Like That: Simple and Efficient Parameter Upcycling for Mixture of Experts

The Mixture of Experts (MoE) framework has become a popular architecture for large language models due to its superior performance over dense models. However, training MoEs from scratch in a large-scale regime is prohibitively expensive. Existing methods m

arxiv.org

 

Abstract

  • 문제정의: 기존 upcycling(예: BTX)은 여러 도메인으로 특화된 dense 모델들의 FFN만 MoE로 복사하고 attention 파라미터는 단순 평균화해 attention의 도메인별 특화 정보 손실 발생함
  • 핵심제안: BAM(Branch-Align-Mix) 파이프라인은 dense 모델을 분기(branch) → 도메인별 continued pretraining으로 특화시키고 → mixture 단계에서 FFN뿐 아니라 attention 파라미터까지 expert로 업사이클링함
  • 변형옵션: Expert-KV(각 expert에 Q/K/V/Wo 전부 복사)와 KV-Sharing(K/V 공유, Wo/Q 등은 expert별 유지) 두 옵션 제시해 성능/추론비용 트레이드오프 탐색함
  • 핵심메커니즘: attention은 soft-routing MoA로 처리해 라우터 불안정성 완화 및 토큰별 혼합 지식 활용 보장함
  • 실험요지: seed 590M~2B 모델 범위에서 perplexity와 downstream 성능이 BTX 대비 일관되게 개선됨

Figure 1: BAM operates in three phases.

Introduction

  • MoE의 장점: 인퍼런스 활성 파라미터를 적게 해서 compute 대비 parameter 효율을 높임
  • 현실적 제약: 맨 처음부터 MoE로 학습하기엔 대규모 자원·안정성 문제로 실무 적용에 제약 있음
  • 기존 upcycling 한계: BTX 등은 FFN expert는 살리지만 attention 파라미터는 평균화해 attention 전문성이 반영되지 못함
  • BAM 목표: dense expert들의 attention 특화 지식을 MoE로 직접 이전해 generalist 성능과 도메인 특화 성능을 동시에 끌어올리는 것
  • 주요 Contribution
    • attention 파라미터 전면 업사이클링 규칙 제안과 두 가지 KV 전략 비교 실험 수행함
    • soft-routing 기반 MoA와 parallel attention transformer 결합으로 throughput 최적화 시도함
    • 광범위한 ablation으로 attention 업사이클링의 본질적 기여 입증함

Related Work

  • Sparse upcycling 요약: single dense 모델 복제 후 FFN만 expert로 만드는 방식과의 차이점 정리함
  • BTX 세부: 여러 도메인으로 특화된 dense 모델을 만든 뒤 FFN만 MoE에 옮겨 성능을 개선하되 attention은 평균화하는 접근임
  • MoA 선행: 기존 MoA는 주로 Q/Wo만 expert로 두거나 K/V를 공유하는 디자인이 많았음, BAM은 K/V까지 적극적으로 expert로 올리는 변형을 실험함
  • 모델 합성/머징 관련: weight averaging, Fisher-weighted merge 등 다른 머징 기법과 비교해 attention을 전문가화하는 점에서 차별화됨

 

Background

  • Multi-Head Attention 정의: attention(Q,K,V)=softmax(QK^T/√d_k)V 구조와 Q=W_q x, K=W_k x, V=W_v x, output = W_o concat(heads) 
  • MoE 기본: 여러 FFN expert와 router로 top-k 선택, auxiliary load-balancing loss L_LB 사용해 expert 불균형 완화함
  • MoA 개념: attention 레이어를 expert화하면 Q/K/V/Wo의 일부 또는 전부를 expert 단위로 두는 설계가 가능함, 이때 routing 방식(soft vs sparse) 설계가 성능과 안정성에 결정적 영향 줌
  • Parallel Attention Transformer: attention과 FFN를 병렬로 처리해 latency/throughput 균형을 맞추는 구조 설명함

 

Method — BAM 전체 개요

  • 세 단계 흐름 요약: Branching → Continued Pretraining → Mixture Initialization & Training 순으로 진행됨
  • 핵심철학: dense expert의 로컬 최적화(도메인 성능)를 보전하면서 MoE의 조합 능력으로 generalist 성능을 확보함
  • 선택할 수 있는 디자인 축: attention의 어느 파라미터를 expert로 할지, routing을 soft로 할지 sparse로 할지, KV를 공유할지 말지 등이 주요 하이퍼파라미터임

 

Step 1 — Branching 세부

  • seed 모델 선택 기준: seed는 일반 도메인에서 충분히 학습된 dense 모델을 선택함, 논문은 590M와 2B 스케일 사용함
  • branching 수: 논문 실험은 보통 4 experts(일반+math+code+law) 구성 실험을 기본으로 보고함
  • 초기 파라미터: 복사본들은 초기 동일 파라미터로 시작하되 이후 도메인별 continued pretraining으로 분화함

Step 2 — Continued Pre-training 

  • 목적: 각 복사본을 특정 도메인 데이터로 추가 학습해 해당 도메인 성능을 극대화함, 이렇게 얻은 dense expert들이 MoE에서 전문성 발휘함
  • 토큰수 기준: 각 expert에 대해 약 100B tokens의 continued pretraining을 주된 실험 설정으로 사용함
  • 데이터 믹스: 각 도메인 데이터에 10% Common Crawl을 섞어 일반성 유지 관찰함
  • 학습세부: continued pretraining에서는 seed pretrain보다 낮은 peak LR을 사용(논문은 일반적으로 seed의 50% 수준 권장)해 안정성 확보함
  • 도메인별 eval set으로 overfit 여부와 generalization(공통 데이터에서의 성능 하락)을 지속 확인함

Step 3 — Mixture Initialization & Training 

  • 파라미터 초기화 규칙:
    • FFN parameters는 각 dense expert의 FFN을 그대로 복사해 MoE의 FFN expert로 초기화함
    • Attention parameters는 실험 옵션에 따라 Expert-KV 또는 KV-Sharing으로 초기화함
    • 비-expert 파라미터(embedding, layernorm 등)는 dense 모델들에서 산술 평균으로 초기화해 안정적 시작점 제공함
    • Router 파라미터는 random init 후 학습으로 맞춤함
  • Routing:
    • FFN은 top-1 sparse routing을 기본 사용해 활성 파라미터 절약함
    • Attention은 soft-routing을 사용해 각 토큰이 모든 attention expert의 가중합으로 처리되게 함, 이는 attention의 정보 손실을 방지하고 학습 안정성에 유리함
  • Auxiliary losses:
    • Load balancing loss L_LB 적용해 expert 할당 불균형 완화함
    • z-loss L_z 적용해 router logits의 수치적 폭주 방지함
  • 학습 안정성 팁: mixture 학습 초반 LR을 낮게 설정하고 warmup 길이를 늘려 gradient spike를 방지함, 논문은 large-scale에서 특히 LR을 많이 낮춤을 명시함

Attention expert 

  • Expert-KV 설명: W_q, W_k, W_v, W_o 등을 각 expert에 전부 복사해 attention 동작을 완전한 expert 단위로 만듦, 성능 최상이나 KV 캐시·메모리 비용 증가 초래함
  • KV-Sharing 설명: W_k, W_v는 공유해 KV 캐시를 하나로 유지하고 W_q/W_o 등만 expert별로 유지함, 추론 메모리·레이턴시 절감 가능하나 성능은 Expert-KV보다 소폭 낮을 수 있음
  • Soft-routing 동기: attention은 토큰 간 상호작용 핵심이므로 sparse hard routing은 정보 단절 유발 가능, soft-routing은 토큰이 여러 expert의 attention을 가중합으로 받게 해 지식 통합에 유리함
  •  

Experimental Setup

  • 공통 옵티마: AdamW, weight decay ~0.1, bfloat16 or mixed precision 권장해 메모리 절감함
  • Small-scale 세팅: seed 590M, seed pretrain ~400B tokens, continued pretrain each expert 100B tokens, mixture training에 총 25 TPU-core-days 사용함
  • Large-scale 세팅: seed 2B, seed pretrain ~750B tokens, continued pretrain each expert 100B tokens, mixture training에 총 305 TPU-core-days 사용함
  • LR 스케줄: continued pretrain은 seed 대비 peak LR의 50% 수준, mixture 학습은 더 낮은 peak LR + 긴 warmup으로 안정화함
  • 배치 및 토크나이저: 논문은 대형 배치(수백만 토큰 단위)로 pretrain을 수행하고 vocab ~256k 사용함
  • 시각자료 추천: Table 8(아키텍처)과 실험 리소스 표로 학습 스케줄을 정리하면 유용함

Data

  • Math 데이터셋 목록: MathGLM, GSM8K(중복 제거 주의), proof-pile-2, MathPile 등 사용함
  • Code 데이터셋: StarCoder 기반 코드 코퍼스 및 공개 코드 스택 활용함
  • Law 데이터셋: pile-of-law, HUPD 등 법률 문헌 중심 데이터로 전문성 확보함
  • 전처리 규칙: 중복 제거, 필터링(너무 짧거나 HTML 등 노이즈 제거), 토큰화 일관성 유지, 평가 데이터와 train 중복 방지 엄격 적용함
  • 데이터 믹스비: 각 도메인 데이터에 10% Common Crawl 섞어 일반성 보완, mixture training 시 도메인별 25% 균등 혼합으로 학습함

Model Architecture 세부(구조·계산량)

  • Small 모델(590M) 예: embedding dim 1024, FFN dim 4096, heads 8, layers 6 등으로 설정함
  • Large 모델(2B) 예: embedding dim 2304, FFN dim 18432, heads 18, layers 18 등으로 설정함
  • 파라미터 카운트와 active vs total 파라미터 수는 Table 9에 레이어별로 정리돼 있어 재현 시 참조 필수함
  • FLOPs 계산 지표: 논문은 non-embedding per-token FLOPs로 모델간 비교 제시함, BAM Expert-KV는 48,257,024 FLOPs, BTX는 21,510,144 FLOPs로 연산 증가 관찰됨

Results

Table1
Table2
Table3

  • Small-scale perplexity(Table 1): BTX 평균 10.72 → BAM DM(Expert-KV) 평균 9.70로 개선되어 MoE 초기화 정책의 효과 입증됨
  • Large-scale perplexity(Table 2): BTX 평균 3.27 → BAM DM(Expert-KV) 평균 3.19로 소폭이나 일관된 개선 관찰됨
  • Downstream 종합(Table 3): average score BTX 32.71 → BAM 34.02로 전반적 향상, 도메인별로 Code/Math/Law 등에서 더 큰 향상폭 관찰됨
  • attention 업사이클링은 단순 parameter matching으로 설명되지 않는 품질 향상을 제공함, 이는 ablation에서 확인됨

 

Ablation Studies

Table6

  • Attention 업사이클링 유무 비교(Table 4): attention 평균화한 모델 대비 attention을 expert로 올린 BAM이 더 우수하므로 attention의 역할 본질적임
  • Soft vs Sparse routing(Table 6): soft-routing MoA가 top-k sparse routing보다 안정적이고 전체 성능에서 우세함, 특히 토큰 수준의 attention 통합이 중요한 시나리오에서 이득 큼
  • Parameter-matched BTX 실험: 파라미터 수/활성 파라미터를 맞춰도 BAM의 성능을 재현 못함 → 단순 용량 확장이 아닌 attention 구조 변화가 핵심 기여임

Inference Efficiency & FLOPs

  • FLOPs 비교: BAM(Expert-KV) 48,257,024 FLOPs vs BTX 21,510,144 FLOPs로 BAM이 연산량 증가 유발함
  • Latency 측정: 예제 설정(16 token 생성, prompt len 256)에서 BTX 4.81s → BAM 6.17s로 느려짐, KV-Sharing으로 5.96s까지 개선 가능함
  • 메모리 관점: Expert-KV는 KV cache가 expert별로 늘어나 메모리 압박 심함, Shared-KV로 완화 가능하나 성능 소폭 저하 가능성 존재함
  • 최적화 전략: parallel attention, expert parallelism, KV 캐시 관리(공유/압축), quantization 등으로 실서비스 적용 가능성 탐색 필요함

Conclusion & Future Work

  • attention 파라미터까지 업사이클링하는 간단한 규칙이 MoE 성능을 실질적으로 향상시킴을 광범위한 실험으로 입증
  • 한계와 트레이드오프: 성능 이득은 있으나 FLOPs/메모리/레이턴시 비용 증가라는 실무적 비용 부담 존재
  • 향후 연구 제안: 데이터 믹스비 자동화, attention-expert 수·구성 자동 탐색(AutoML), KV 캐시 압축/공유 최적화, 하드웨어 친화적 expert parallelism 연구 권장