JAX는 구글이 2018년경부터 개발한 「수치 계산 + 자동 미분 + GPU/TPU 가속」을 위한 차세대 라이브러리입니다.
NumPy와 거의 같은 API를 제공하면서도 GPU·TPU에서 동작하고, 자동 미분과 JIT 컴파일을 지원합니다.
JAX의 철학은 「함수형 프로그래밍」입니다.
모든 함수를 「입력 → 출력」의 순수 함수로 다루며, 부수 효과를 최소화합니다.
이 덕분에 코드를 「변환(transform)」하기 쉬워, jit(컴파일), grad(미분), vmap(벡터화), pmap(병렬화) 같은 강력한 변환 함수를 자유롭게 결합할 수 있습니다.
예를 들어 「학습 함수를 만든 뒤 jit로 컴파일해 100배 빨라지게 하고, vmap으로 배치 처리하고, pmap으로 8개 GPU에 분산하기」를 코드 한 줄씩 더해 가며 자연스럽게 할 수 있습니다.
PyTorch나 TensorFlow에서는 같은 일을 하려면 별도의 도구·플래그·설정이 필요합니다.
비유하자면 JAX는 「함수의 변신 로봇」과 같습니다.
같은 함수를 빠른 버전·미분 가능한 버전·여러 GPU 동시 실행 버전으로 자유롭게 변환할 수 있고, 이 변환들이 깔끔하게 조합됩니다.
수학적 우아함이 있는 프레임워크입니다.
JAX는 구글 DeepMind가 자기 연구에 적극 사용하며, AlphaFold·Gemini·MaxText 같은 큰 모델들이 JAX로 학습됐습니다.
다만 산업 채택은 PyTorch·TensorFlow보다 적고, 학습 곡선이 가파르다는 평이 있어 입문자에게는 어렵습니다.
Flax·Haiku 같은 고수준 라이브러리가 진입 장벽을 낮추고 있습니다.
한 줄 요약
JAX는 NumPy 같은 API + 자동 미분 + JIT + vmap·pmap 변환을 결합한 함수형 ML 프레임워크입니다.
DeepMind가 적극 사용하며 큰 모델 학습에 강하지만 학습 곡선은 가파릅니다.
더 알아볼 것
- Flax·Haiku — JAX 위의 고수준 라이브러리
- XLA — JAX·TF가 공유하는 컴파일러
- TPU와 JAX의 궁합