Topic

FlashAttention

JackerLab 2025. 5. 16. 01:50
728x90
반응형

개요

FlashAttention은 GPU 메모리 대역폭을 효율적으로 활용하여 Transformer 모델의 Attention 연산을 빠르고 정확하게 수행할 수 있도록 설계된 메모리 최적화 기반의 고속 어텐션 알고리즘입니다. 기존의 Softmax Attention 구현 방식은 쿼리-키-값 연산 시 많은 중간 메모리를 생성하고 이를 반복적으로 읽고 쓰는 비효율적인 구조였으나, FlashAttention은 이 과정을 최소화하여 큰 시퀀스 길이에서도 속도와 정확성을 유지할 수 있도록 합니다.


1. 개념 및 정의

항목 설명
정의 FlashAttention은 CUDA 기반 커널을 활용하여 attention score와 softmax 계산을 메모리 낭비 없이 수행하는 고속 알고리즘입니다.
목적 Transformer의 memory bottleneck 해소와 대규모 모델 학습 시간 단축
필요성 GPU에서 시퀀스 길이 증가 시 Softmax Attention은 매우 느려지고 메모리 사용량이 급증

FlashAttention은 “O(N²) 복잡도를 O(N²) 시간에 맞게 실질적으로 구현”한 효율화 기술입니다.


2. 핵심 아이디어

개념 설명
수직적 연산 방식 Query별로 연산을 수행하며 메모리를 재활용하는 방식으로 Softmax를 계산
Tiling & Reuse 고정된 CUDA shared memory 공간에 Tile 단위로 QKᵗ를 계산 및 Softmax 누적
On-the-fly Normalization 전체 Row에 대해 나중에 Softmax를 정규화하는 대신 행별 최대값/합을 즉시 계산
Fused Kernel 여러 연산을 하나의 CUDA 커널로 통합하여 중간 저장을 제거함

이 방식은 정확도를 손상시키지 않으면서도 연산 효율성을 극대화합니다.


3. 성능 비교 및 벤치마크

항목 FlashAttention PyTorch 기본 Attention
연산 속도 최대 2~4배 향상 기존 구현 대비 느림
GPU 메모리 사용량 최대 50% 감소 QKᵗ, Softmax, V 저장 필요
시퀀스 길이 확장성 수천~수만 길이에서도 안정 작동 시퀀스 길이 증가 시 OOM 발생 빈도 높음
정확도 손실 없음 없음 (동일 결과 보장)

FlashAttention은 특히 LLM(예: GPT, BERT, T5)에서 효율적인 사전학습 및 추론을 가능하게 합니다.


4. 구조 및 계산 흐름

  1. Query–Key 행렬 QKᵗ를 Tile 단위로 나누어 shared memory에 적재
  2. 각 Tile에서 Softmax 분모 계산 (누적합/최댓값 유지)
  3. Softmax 계산과 동시에 V 행렬 곱 연산 수행 (QKᵗV)
  4. 결과를 global memory로 저장

이 모든 과정을 하나의 CUDA 커널로 처리하여 write/read 병목을 제거합니다.


5. 활용 도구 및 적용 사례

도구 설명 활용 예시
FlashAttention Library 공식 PyTorch/CUDA 구현 제공 GPT, BERT, T5 모델 학습 가속
HuggingFace Transformers FlashAttention과 통합된 backend 지원 LLM 추론 속도 향상
xFormers Meta AI의 고속 Attention 프레임워크 Flash+Block+Sparse 연산 조합
Triton 커스텀 커널 최적화 프레임워크 사용자 정의 attention 변형 구현 가능

LLM 학습 비용 절감과 추론 지연 시간 개선을 위해 주요 플랫폼에서 지원 확대 중입니다.


6. 장점 및 효과

항목 설명
연산 속도 향상 Fused 연산과 memory tiling으로 연산 속도 2~4배 향상
메모리 절약 중간 텐서 저장 생략으로 GPU memory footprint 절감
시퀀스 길이 확장 4K 이상 입력에서도 안정적 연산 가능
정확도 보존 Softmax 연산 순서를 조정하되 수치적 정확성 유지

FlashAttention은 대규모 Transformer 학습의 실질적 한계를 극복한 대표적 사례입니다.


7. 결론

FlashAttention은 Transformer 모델 학습에서 가장 큰 병목 중 하나인 Attention 연산의 메모리 및 속도 문제를 혁신적으로 해결한 알고리즘입니다. 특히 GPU 자원 활용이 중요한 LLM 훈련 및 추론 환경에서 효율성과 정확성을 동시에 제공함으로써, 차세대 모델의 스케일 확장과 연구 생산성에 크게 기여할 것으로 기대됩니다.

728x90
반응형

'Topic' 카테고리의 다른 글

FIM (Fill-In-the-Middle) Pre-training  (0) 2025.05.16
LongNet  (0) 2025.05.16
Mo’s Algorithm  (0) 2025.05.16
Link-Cut Tree  (0) 2025.05.15
OpenTitan  (1) 2025.05.15