dm-haiku: 객체 지향 모델 정의를 가능하게 하는 JAX용 신경망 라이브러리

dm-haiku: 객체 지향 모델 정의를 가능하게 하는 JAX용 신경망 라이브러리

해결하는 문제

JAX는 강력한 수치 계산 라이브러리이지만, 변환(예: jitgrad)을 사용하려면 함수가 순수해야 합니다. Haiku는 개발자가 익숙한 객체 지향 프로그래밍 모델을 사용하여 신경망을 정의할 수 있도록 함으로써 이 문제를 해결하며, 이러한 "불순한" 객체 지향 정의를 JAX가 처리할 수 있는 순수 함수로 자동으로 변환합니다.

작동 방식

Haiku는 객체 지향 설계와 함수적 순수성 사이의 간극을 메우기 위해 두 가지 주요 도구를 제공합니다:

  • hk.Module: 네트워크 레이어와 구성 요소를 정의하는 데 사용되는 Python 객체입니다. 이러한 모듈은 매개변수에 대한 참조와 메서드를 보유하여 사용자가 표준 신경망 라이브러리와 유사한 코드를 작성할 수 있도록 합니다.
  • hk.transform: hk.Module을 사용하는 함수를 init(초기 매개변수 값을 수집함)과 apply(계산을 위해 해당 매개변수를 함수에 다시 주입함)라는 한 쌍의 순수 함수로 변환하는 함수 변환입니다.

배치 정규화(batch normalization)와 같이 내부 가변 상태가 필요한 모델의 경우, Haiku는 매개변수와 상태를 별도로 관리하는 hk.transform_with_state를 제공합니다.

대상 사용자

신경망 구축을 위한 객체 지향 API의 생산성을 원하면서도 JAX의 함수 변환 및 하드웨어 가속의 모든 기능을 필요로 하는 연구자 및 개발자.

주요 특징

  • DeepMind 규모: 대규모 이미지, 언어 및 강화 학습 실험을 위해 DeepMind 연구원들에 의해 테스트되었습니다.
  • 라이브러리, 프레임워크가 아님: 사용자 정의 최적화 도구(optimizer)나 체크포인트 형식을 강요하지 않고 매개변수 및 상태 관리에 집중하는 경량 라이브러리로 설계되었습니다.
  • hk.next_rng_key(): 결정론적 키 시퀀스를 제공하여 JAX 내에서 난수 생성을 단순화합니다.
  • JAX 호환성: 여러 가속기에서 분산 학습을 위해 jax.pmap과 완전히 호환됩니다.

Sources