개요
FSDP(Fully Sharded Data Parallel)는 PyTorch에서 제공하는 분산 학습 기법 중 하나로, 모델의 파라미터, 그래디언트, 옵티마이저 상태를 GPU 간에 샤딩(sharding)하여 메모리 사용량을 획기적으로 절감할 수 있는 방식입니다. 특히 GPT, BERT와 같은 초대규모 모델의 학습에 적합하며, 기존 DataParallel, DDP(Distributed Data Parallel) 방식의 메모리 한계를 극복합니다.
본 글에서는 FSDP의 원리, 구성, 주요 기술, 성능 효과 등을 분석하고 실제 적용 시 고려할 전략을 소개합니다.
1. 개념 및 정의
항목 | 설명 |
정의 | FSDP는 모델 파라미터 및 상태 정보를 GPU 간에 분산 저장하고 통신을 통해 학습을 수행하는 분산 학습 방식입니다. |
목적 | 초대규모 모델 학습 시 GPU 메모리 효율을 높이고 학습 가능한 모델 크기를 확장 |
필요성 | 기존 DDP 방식은 전체 모델을 각 GPU에 복제 → 메모리 병목 발생 |
FSDP는 메모리 최적화 + 통신 최적화를 결합한 고성능 학습 프레임워크입니다.
2. 특징
특징 | 설명 | 기존 방식과 비교 |
파라미터 샤딩 | 각 GPU가 전체 모델 중 일부만 보유 | DDP는 전체 모델 복제 |
통신 최적화 | All-Gather/Reduce-Scatter를 효율적으로 활용 | NCCL 통신 레이어 최적화 |
Flat Parameter 구조 | 수천 개 파라미터를 하나의 덩어리로 처리 | 메타데이터 및 overhead 감소 |
FSDP는 메모리/통신/성능 간 균형을 최적화합니다.
3. 아키텍처 및 동작 방식
단계 | 설명 | 연산 흐름 |
1. Initialization | 모델 파라미터를 균등하게 분할 및 각 GPU에 배치 | Shard 생성 |
2. Forward Pass | 필요 시 다른 GPU로부터 All-Gather → 순전파 실행 | 활성화 저장 최소화 |
3. Backward Pass | 그래디언트 계산 후 Reduce-Scatter → 로컬 그래디언트 업데이트 | 메모리 재활용 |
4. Optimizer Step | Optimizer state도 샤딩하여 업데이트 수행 | Adam 등 상태 분산 저장 |
FSDP는 GPU 간 통신 비용과 메모리 효율을 정교하게 조율합니다.
4. 주요 기술 요소
기술 요소 | 설명 | 효과 |
Mixed Precision (FP16/BF16) | 메모리 및 연산 최적화를 위한 저정밀 연산 | 학습 속도 향상 + 메모리 절약 |
CPU Offloading | Optimizer state를 CPU 메모리로 이전 | GPU 메모리 사용량 감소 |
Activation Checkpointing | 중간 연산 결과를 저장하지 않고 재계산 | 메모리 사용량 대폭 절감 |
Wrap Policy | 특정 모듈만 샤딩 대상으로 지정 가능 | Layer-wise 샤딩 전략 가능 |
FSDP는 고급 사용자 설정이 가능한 매우 유연한 프레임워크입니다.
5. 장점 및 이점
장점 | 설명 | 기대 효과 |
메모리 효율 극대화 | 모델, 옵티마이저, 그래디언트 모두 샤딩 | 학습 가능한 모델 사이즈 증가 |
통신 효율 | 필요 시에만 All-Gather 수행 | 통신 병목 최소화 |
유연한 설정 | 다양한 샤딩/랩핑/정밀도 설정 가능 | 작업 환경에 최적화된 학습 설계 가능 |
FSDP는 100억 파라미터 이상 모델 학습 시 기본 선택지로 자리잡고 있습니다.
6. 활용 사례 및 고려사항
활용 사례 | 설명 | 고려 사항 |
GPT-3/LLAMA 학습 | 초대규모 모델 학습 시 GPU 메모리 한계 극복 | Gradient Clipping, Offload 정책 등 세팅 필요 |
메모리 부족한 환경의 튜닝 | 1-2 GPU에서도 큰 모델 학습 가능 | CPU Offload 시 속도 저하 가능성 고려 |
연구용 실험 환경 구성 | Layer-wise profiling, WrapPolicy 조절 가능 | Checkpointing + FP16 조합 시 오류 주의 |
성공적인 도입을 위해서는 통신량, 메모리 사용량, 학습 속도 간 trade-off 분석이 중요합니다.
7. 결론
FSDP는 PyTorch 기반에서 초대규모 모델 학습을 위한 강력한 분산 학습 전략으로, 메모리 샤딩, 통신 최적화, 정밀도 설정 등 다양한 기능을 통해 메모리 병목 문제를 해결합니다. 특히 Transformer 계열 대형 모델 학습에 필수적인 구성으로 자리잡았으며, 연구/상용 환경 모두에 적용할 수 있는 유연성과 성능을 제공합니다.
대규모 모델 학습 환경을 설계 중이라면, FSDP는 반드시 검토해야 할 핵심 기술입니다.