개요JAX Pallas는 Google JAX 생태계에서 고성능 커널을 직접 정의할 수 있도록 설계된 Python 기반 DSL(Domain-Specific Language)이다. 기존 XLA 컴파일러에 전적으로 의존하던 방식에서 벗어나, 개발자가 세밀한 메모리 제어 및 병렬 실행 구조를 정의할 수 있도록 지원한다. 특히 Mosaic GPU 및 TPU 아키텍처에 최적화된 저수준 연산을 구현할 수 있으며, Triton과 유사한 프로그래밍 모델을 JAX 내부로 통합한 것이 특징이다.1. 개념 및 정의JAX Pallas는 JAX의 고수준 함수형 인터페이스를 유지하면서도 CUDA 수준의 세밀한 제어를 가능하게 하는 커널 작성 프레임워크이다. JAX의 jit, vmap, pmap과 결합되어 자동 미분 및 벡터화와 자..