TensorFlow有了替代品,竟然還是谷歌自己做出來的?這其實是TensorFlow的一個簡化庫,名為JAX,可以支持部分TensorFlow的功能,但是比TensorFlow更加簡潔易用。
什么?TensorFlow 有了替代品?什么?竟然還是谷歌自己做出來的?先別慌,從各種意義上來說,這個所謂的 “替代品” 其實是 TensorFlow 的一個簡化庫,名為JAX,結合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加簡潔易用。
雖然還不至于替代 TensorFlow,但已經有 Reddit 網友對 JAX 寄予厚望,并表示“早就期待能有一個可以直接調用 Numpy API 接口的庫了!”,“希望它可以取代 TensorFlow!”。
JAX 結合了 Autograd 和 XLA,是專為高性能機器學習研究打造的產品。
有了新版本的Autograd,JAX 能夠自動對 Python 和 NumPy 的自帶函數求導,支持循環、分支、遞歸、閉包函數求導,而且可以求三階導數。它支持自動模式反向求導(也就是反向傳播)和正向求導,且二者可以任意組合成任何順序。
JAX 的創新之處在于,它基于XLA在 GPU 和 TPU 上編譯和運行 NumPy 程序。默認情況下,編譯是在底層進行的,庫調用能夠及時編譯和執行。但是 JAX 還允許使用單一函數 API jit將自己的 Python 函數及時編譯成經過 XLA 優化的內核。編譯和自動求導可以任意組合,因此可以在不脫離 Python 環境的情況下實現復雜算法并獲得最優性能。
JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 發起,他們均任職于谷歌大腦團隊。在 GitHub 的說明文檔中,作者明確表示:JAX 目前還只是一個研究項目,不是谷歌的官方產品,因此可能會有一些 bug。從作者的 GitHub 簡介來看,這應該是谷歌大腦正在嘗試的新項目,在同一個 GitHub 目錄下的開源項目還包括 8 月份在業內引起熱議的強化學習框架 Dopamine。
以下是 JAX 的簡單使用示例。
GitHub 項目傳送門:https://github.com/google/JAX
有關具體的安裝和簡單的入門指導大家可以在 GitHub 中自行查看,在此不做過多贅述。
JAX 庫的實現原理
機器學習中的編程是關于函數的表達和轉換。轉換包括自動微分、加速器編譯和自動批處理。像 Python 這樣的高級語言非常適合表達函數,但是通常使用者只能應用它們。我們無法訪問它們的內部結構,因此無法執行轉換。
JAX 可以用于專門化高級Python+NumPy函數,并將其轉換為可轉換的表示形式,然后再提升為 Python 函數。
JAX 通過跟蹤專門處理 Python 函數。跟蹤一個函數意味著:監視應用于其輸入,以產生其輸出的所有基本操作,并在有向無環圖 (DAG) 中記錄這些操作及其之間的數據流。為了執行跟蹤,JAX 包裝了基本的操作,就像基本的數字內核一樣,這樣一來,當調用它們時,它們就會將自己添加到執行的操作列表以及輸入和輸出中。為了跟蹤這些原語之間的數據流,跟蹤的值被包裝在 Tracer 類的實例中。
當 Python 函數被提供給 grad 或 jit 時,它被包裝起來以便跟蹤并返回。當調用包裝的函數時,我們將提供的具體參數抽象到 AbstractValue 類的實例中,將它們框起來用于跟蹤跟蹤器類的實例,并對它們調用函數。
抽象參數表示一組可能的值,而不是特定的值:例如,jit 將 ndarray 參數抽象為抽象值,這些值表示具有相同形狀和數據類型的所有 ndarray。相反,grad 抽象 ndarray 參數來表示底層值的無窮小鄰域。通過在這些抽象值上跟蹤 Python 函數,我們確保它足夠專門化,以便轉換是可處理的,并且它仍然足夠通用,以便轉換后的結果是有用的,并且可能是可重用的。然后將這些轉換后的函數提升回 Python 可調用函數,這樣就可以根據需要跟蹤并再次轉換它們。
JAX 跟蹤的基本函數大多與 XLA HLO 1:1 對應,并在 lax.py 中定義。這種 1:1 的對應關系使得到 XLA 的大多數轉換基本上都很簡單,并且確保我們只有一小組原語來覆蓋其他轉換,比如自動微分。 jax.numpy 層是用純 Python 編寫的,它只是用 LAX 函數 (以及我們已經編寫的其他 numpy 函數) 表示 numpy 函數。這使得 jax.numpy 易于延展。
當你使用 jax.numpy 時,底層 LAX 原語是在后臺進行 jit 編譯的,允許你在加速器上執行每個原語操作的同時編寫不受限制的 Python+ numpy 代碼。
但是 JAX 可以做更多的事情:你可以在越來越大的函數上使用jit來進行端到端編譯和優化,而不僅僅是編譯和調度到一組固定的單個原語。例如,可以編譯整個網絡,或者編譯整個梯度計算和優化器更新步驟,而不僅僅是編譯和調度卷積運算。
折衷之處是,jit 函數必須滿足一些額外的專門化需求:因為我們希望編譯專門針對形狀和數據類型的跟蹤,但不是專門針對具體值的跟蹤,所以 jit 裝飾器下的 Python 代碼必須適用于抽象值。如果我們嘗試在一個抽象的 x 上求 x >0 的值,結果是一個抽象的值,表示集合 {True, False},所以 Python 分支就像 if x > 0 會引起報錯。
有關使用 jit 的更多要求,請參見:https://github.com/google/jax#whats-supported
好消息是,jit 是可選的:JAX 庫在后臺對單個操作和函數使用 jit,允許編寫不受限制的 Python+Numpy,同時仍然使用硬件加速器。但是,當你希望最大化性能時,通常可以在自己的代碼中使用 jit 編譯和端到端優化更大的函數。
后續計劃
目前項目小組還將對以下幾項做更多嘗試和更新:
完善說明文檔
支持 Cloud TPU
支持多 GPU 和多 TPU
支持完整的 NumPy 功能和部分 SciPy 功能
全面支持 vmap
加速
降低 XLA 函數調度開銷
線性代數例程(CPU 上的 MKL 和 GPU 上的 MAGMA)
高效自動微分原語cond和while
有關 JAX 庫的介紹大致如此。
-
谷歌
+關注
關注
27文章
6171瀏覽量
105494 -
機器學習
+關注
關注
66文章
8422瀏覽量
132742 -
tensorflow
+關注
關注
13文章
329瀏覽量
60537
原文標題:要替代 TensorFlow?谷歌開源機器學習庫 JAX
文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論