Topic

FSDP (Fully Sharded Data Parallel)

JackerLab 2025. 9. 8. 06:00
728x90
반응형

개요

FSDP(Fully Sharded Data Parallel)는 PyTorch에서 제공하는 분산 학습 기법 중 하나로, 모델의 파라미터, 그래디언트, 옵티마이저 상태를 GPU 간에 샤딩(sharding)하여 메모리 사용량을 획기적으로 절감할 수 있는 방식입니다. 특히 GPT, BERT와 같은 초대규모 모델의 학습에 적합하며, 기존 DataParallel, DDP(Distributed Data Parallel) 방식의 메모리 한계를 극복합니다.

본 글에서는 FSDP의 원리, 구성, 주요 기술, 성능 효과 등을 분석하고 실제 적용 시 고려할 전략을 소개합니다.


1. 개념 및 정의

항목 설명
정의 FSDP는 모델 파라미터 및 상태 정보를 GPU 간에 분산 저장하고 통신을 통해 학습을 수행하는 분산 학습 방식입니다.
목적 초대규모 모델 학습 시 GPU 메모리 효율을 높이고 학습 가능한 모델 크기를 확장
필요성 기존 DDP 방식은 전체 모델을 각 GPU에 복제 → 메모리 병목 발생

FSDP는 메모리 최적화 + 통신 최적화를 결합한 고성능 학습 프레임워크입니다.


2. 특징

특징 설명 기존 방식과 비교
파라미터 샤딩 각 GPU가 전체 모델 중 일부만 보유 DDP는 전체 모델 복제
통신 최적화 All-Gather/Reduce-Scatter를 효율적으로 활용 NCCL 통신 레이어 최적화
Flat Parameter 구조 수천 개 파라미터를 하나의 덩어리로 처리 메타데이터 및 overhead 감소

FSDP는 메모리/통신/성능 간 균형을 최적화합니다.


3. 아키텍처 및 동작 방식

단계 설명 연산 흐름
1. Initialization 모델 파라미터를 균등하게 분할 및 각 GPU에 배치 Shard 생성
2. Forward Pass 필요 시 다른 GPU로부터 All-Gather → 순전파 실행 활성화 저장 최소화
3. Backward Pass 그래디언트 계산 후 Reduce-Scatter → 로컬 그래디언트 업데이트 메모리 재활용
4. Optimizer Step Optimizer state도 샤딩하여 업데이트 수행 Adam 등 상태 분산 저장

FSDP는 GPU 간 통신 비용과 메모리 효율을 정교하게 조율합니다.


4. 주요 기술 요소

기술 요소 설명 효과
Mixed Precision (FP16/BF16) 메모리 및 연산 최적화를 위한 저정밀 연산 학습 속도 향상 + 메모리 절약
CPU Offloading Optimizer state를 CPU 메모리로 이전 GPU 메모리 사용량 감소
Activation Checkpointing 중간 연산 결과를 저장하지 않고 재계산 메모리 사용량 대폭 절감
Wrap Policy 특정 모듈만 샤딩 대상으로 지정 가능 Layer-wise 샤딩 전략 가능

FSDP는 고급 사용자 설정이 가능한 매우 유연한 프레임워크입니다.


5. 장점 및 이점

장점 설명 기대 효과
메모리 효율 극대화 모델, 옵티마이저, 그래디언트 모두 샤딩 학습 가능한 모델 사이즈 증가
통신 효율 필요 시에만 All-Gather 수행 통신 병목 최소화
유연한 설정 다양한 샤딩/랩핑/정밀도 설정 가능 작업 환경에 최적화된 학습 설계 가능

FSDP는 100억 파라미터 이상 모델 학습 시 기본 선택지로 자리잡고 있습니다.


6. 활용 사례 및 고려사항

활용 사례 설명 고려 사항
GPT-3/LLAMA 학습 초대규모 모델 학습 시 GPU 메모리 한계 극복 Gradient Clipping, Offload 정책 등 세팅 필요
메모리 부족한 환경의 튜닝 1-2 GPU에서도 큰 모델 학습 가능 CPU Offload 시 속도 저하 가능성 고려
연구용 실험 환경 구성 Layer-wise profiling, WrapPolicy 조절 가능 Checkpointing + FP16 조합 시 오류 주의

성공적인 도입을 위해서는 통신량, 메모리 사용량, 학습 속도 간 trade-off 분석이 중요합니다.


7. 결론

FSDP는 PyTorch 기반에서 초대규모 모델 학습을 위한 강력한 분산 학습 전략으로, 메모리 샤딩, 통신 최적화, 정밀도 설정 등 다양한 기능을 통해 메모리 병목 문제를 해결합니다. 특히 Transformer 계열 대형 모델 학습에 필수적인 구성으로 자리잡았으며, 연구/상용 환경 모두에 적용할 수 있는 유연성과 성능을 제공합니다.

대규모 모델 학습 환경을 설계 중이라면, FSDP는 반드시 검토해야 할 핵심 기술입니다.

728x90
반응형