연구/Natural Language Processing

[Paper Review] A Simple and Effective Pruning Approach for Large Language Models

서히! 2025. 9. 7. 21:57

https://arxiv.org/abs/2306.11695

 

A Simple and Effective Pruning Approach for Large Language Models

As their size increases, Large Languages Models (LLMs) are natural candidates for network pruning methods: approaches that drop a subset of network weights while striving to preserve performance. Existing methods, however, require either retraining, which

arxiv.org

 

본 Paper Review에 앞서, 이 Paper를 선택한 이유

  • 현재 pruning 관련 연구를 진행 중이며, LLM(대형 언어 모델) 대상의 post-training pruning(사후 프루닝) 기법에 대한 실용적·간단한 베이스라인이 필요
  • Wanda는 계산 비용이 낮고 retraining 없이(혹은 최소의 데이터만으로) 적용 가능한 방법이라서 baseline/비교군으로 적절하다고 판단

0. Abstract

  • network weight의 subset을 drop하는기존 method에는 두가지 문제점이 존재
    1. 감당할 수 없는 retraining이 필요
    2. second-order information에 의존하는 reconstruction problem 발생
       Pruning by Weighs and activations (Wanda) 기법 제시
      • output 별로 input activation과 곱해진 가중치 크기가 가장 작은 것들을 pruning
      • retraining이나 weight를 update하지 않고, 직관적인 방법
Second Order Information
머신러닝에서 더 빨리 수렴하기 위해 gradient와 곡률(curvature) 정보를 활용하는 최적화 방법

 

 

1. Introduction

  • 대부분의 LLM은 상당한 많은 양의 parameter로 computational resource가 많이 필요
     양자화 방법(parameter를 저비트 수준의 representation으로 변환)
     Network pruning 방법(weight를 0으로 설정하여 제거)
  • 기존 pruning method의 문제점
    1. retraining (Liu et al., 2019; Blalock et al., 2020)
    2. random initialization (Zhu & Gupta, 2017; Louizos et al., 2018; Gale et al., 2019)
    3. extensive 반복 과정 (Frankle & Michael, 2019; Renda et al., 2020)
  • Wanda 기법  💡  
    • weight를 수정하지 않고 높은 sparsity의 pruning (= 많이 pruning이 가능해짐)
      • hidden state feature의 일부가 큰 값을 가진다는 것에서 착안
    • weight magnitude pruning metric + input activation의 norm 곱이 weight importance를 평가하는 데에 효과적
    • 따라서, linear layer의 각 output 안에서의 weight끼리 비교하여 우선순위가 작은 weight를 제거하는 method

 

2. Preliminaries

  1. Magnitude Pruning
    • 각 laher 안에서 weight locally하게 비교해서 threshold를 정함
      or
    • 전체 network에서 globally하게 정함
  2. Emergenet Large Magnitude Features
    • hidden state feature의 일부가 다른 것들에 비해 큰 수치를 가지고 있음 = outlier features
      1. 매우 큰 수치이며 (전형적인 hidden state value에 비해 100배 이상 차이 발생),
      2. 특정 feature 차원에서만 나타나고
      3. LLM의 predictive한 능력을 보기에 필수이기에 중요
  3. Calibration dataset
    • 목적pruning 과정에서 모델의 input acviation의 특성을 추정하기 위해 사용하는 소규모 데이터 샘플 모음, 각 입력 채널 h 의 활성화 통계를 추정()
    • 논문 권장값: C4에서 128 sequences를 사용(실험적으로 128이면 안정적이며 더 늘리면 완만 개선)

Wanda Overview, activation:X, weight: W

3. Wanda: Pruning by Weights and Activations

  • Pruning Metric
    채널 노름: n_j := ||X_j||2

 

  • input channel의 weight가 magnitude(절댓값) feature가 클 경우, weight의 magnitude가 작아도 imporatnce score가 커짐으로써 large magnitude feature가 Wanda에서는 유지됨을 알 수 있음
  • 프루닝 규칙: 각 output row 에 대해 S_i 를 정렬하고 하위 에 해당되는 weight를 0으로 만듦 (row-wise unstructured pruning)

⇒ SparseGPT(및 OBD/OBS) 관점에서의 local reconstruction objective를 단순화하면 Hessian의 대각 성분들이 분모에 나타난다. 대각 근사(=비상관 가정)과 damping λ→0 를 적용하면 중요도는 대략 위 식으로 유도 가능하다.
Wanda는 SparseGPT의 복잡한 2차 정보 지표를 효율적으로 근사하는 직관적 지표

 

기존 Magnitude 방식, sparse gpt, wanda와의 차이

헤시안 행렬
함수의 curvature(곡률)을 분석하기 위한 2차 미분 행렬
딥러닝에서는 loss function이나 레이어의 출력에 대해 미분을 여러 번 해서 얻는 행렬로, 각 파라미터 쌍의 곡률 정보를 담고 있음
Damping (댐핑)
헤시안 행렬을 역행렬로 계산하려면 불안정할 수 있으므로, 헤시안이 거의 0에 가까운 값이 생기면 역행렬 계산이 되지 않아 결과가 왜곡될 수 있으므로 이런 현상을 막기 위해 H에 작은 수 람다를 더해서 안정화시키는 방법
Pseudo Code
# W: (C_out, C_in), sum_sq: (C_in,) accumulated across calibration tokens
n = torch.sqrt(sum_sq)          # (C_in,)
metric = W.abs() * n.unsqueeze(0)  # broadcast -> (C_out, C_in)
k = int(C_in * sparsity)         # number to prune per row
# get indices of smallest k elements per row
vals, idx = torch.topk(metric, k, dim=1, largest=False)
W_pruned = W.clone()
W_pruned.scatter_(1, idx, 0.0)​

 

4. Experiments

Models & Baselines

    • 모델들: LLaMA 계열(7B, 13B, 30B, 65B) 및 LLaMA-2(7B, 13B, 70B)
    • Calibration data: C4에서 128 sequences (기본). 일부 실험은 1, 8, 128, 1024 등 다양한 샘플 수로 민감도 실험
    • 평가:
      • Zero-shot accuracy: 논문에서 7개 태스크 평균(예: PIQA, WinoGrande, ARC, HellaSwag)
      • Perplexity: WikiText(또는 텍스트 데이터 셋) 검증 셋에 대한 perplexity 측정
      • 속도: metric 계산 시간, 전체 프루닝 처리 시간, 그리고 구조화(2:4) 적용 시 matmul inference throughput 측정

4.1 Zero-Shot Tasks

 

 

 

  • Magnitude pruning → 성능 급락, 예: LLaMA-7B, 50% sparse에서 55% 수준
  • Wanda pruning → 50% sparse에서도 Dense 성능과 거의 유사, SparseGPT와 근접
  • SparseGPT → 여전히 가장 높은 성능이지만 Wanda와 차이가 작음
  • LLaMA-65B (50% sparse) 에서는 Wanda가 Dense와 거의 동일한 평균 zero-shot 정확도 달성 → Large Sparse 모델이 Small Dense 모델보다 나음 

4.2 Language Modeling

 

  • 주요 결과 (LLaMA-7B, 50% unstructured sparsity):
    • Dense: (기본 perplexity 수치)
    • Magnitude: 17.29
    • Wanda: 7.26
    • SparseGPT: 7.22
  • magnitude 대비 압도적 개선, SparseGPT와 사실상 동일 성능
  • 다른 모델 크기(LLaMA-13B, 30B, 65B / LLaMA-2 7B, 13B, 70B)에서도 동일 경향

4.3 Speed up

  • Wanda는 한 레이어의 metric 계산 시간이 SparseGPT 대비 수십 배 빠름.
  • 전체 모델 pruning도 훨씬 빠르게 끝남 → 대규모 LLM baseline 실험에 적합

 

왼쪽표: 누적 시간, 오른쪽표: 각 단계별 pruuning speed

 

 

5. Analysis

  • Fine-tuning 효과
    • Wanda로 pruning한 LLaMA-7B 모델에서 성능 하락 발생함
    • LoRA나 full parameter fine-tuning 적용하면 성능 복구 가능
    • LoRA는 적은 자원으로도 일정 수준 회복 가능
    • full fine-tuning은 dense 모델 성능에 근접하게 끌어올림

 

  • Pruning configuration
    • Wanda는 weight importance metric과 comparison group 모두 기존 방법과 다름
    • Ablation 실험 결과, per-output 기준의 grouping이 가장 효과적
    • magnitude pruning도 grouping 방식에 따라 성능 차이가 큼

 

 

  • Calibration data 개수 영향
    • Wanda는 calibration sample 수가 적어도 안정적
    • SparseGPT는 샘플 수 적으면 성능 불안정하지만, Wanda는 단 1개 샘플로도 성능 유지

 

 

  • Weight update 효과
    • SparseGPT는 weight update 필요하지만 Wanda는 거의 개선 X
    • unstructured 50%와 4:8 sparsity에서 update 불필요, 2:4 sparsity에서만 아주 미미한 개선 있음

 

6. Conclusion

  • Wanda는 LLM을 위해 단순하면서도 효과적인 pruning 방법 제시
  • 핵심은 weight 크기 × input activation norm 기반 pruning metric + per-output basis 비교
  • retraining이나 weight update 필요 없이 pretrained LLM에서 바로 sparsity 유도 가능
  • 실험 결과, magnitude pruning보다 훨씬 우수하고 SparseGPT에 준하는 성능 보임
  • Wanda는 빠른 속도와 간단한 구조 덕분에 sparse training 환경에도 유용할 가능성 있음