聯(lián)邦學習是一種機器學習設置,允許多個客戶端(即移動設備或者整個組織,取決于正在參與的任務)在一個中央服務器的編排下,協(xié)同訓練同一個模型,同時還能保持訓練數(shù)據(jù)的離散性。例如,通過聯(lián)邦學習,可以基于永遠不會從移動設備中消失的用戶數(shù)據(jù)訓練虛擬鍵盤語言模型。
要實現(xiàn)這點,聯(lián)邦學習算法首先需要初始化服務器中的模型,然后完成以下對于每一輪訓練而言都非常關鍵的三步:
1. 服務器將模型發(fā)送到一組采樣客戶端。
2. 這些采樣客戶端在本地數(shù)據(jù)中訓練模型。
3. 訓練完成之后,客戶端將更新后的模型發(fā)送到服務器,然后服務器將所有這些模型匯總在一起。
隨著人們對隱私和安全的日益注重,聯(lián)邦學習已成為一個尤為活躍的研究領域。對于這個日新月異的領域,能夠輕松將想法轉換為代碼、快速迭代,以及比較和復制現(xiàn)有基線的重要性不言而喻。
日新月異的領域
https://research.google/pubs/pub49232/
因此,我們很高興為大家介紹 FedJAX。FedJAX 是一個基于 JAX 的開源庫,適用于注重研究易用性的聯(lián)邦學習模擬。FedJAX 擁有適用于執(zhí)行聯(lián)邦算法、預打包的數(shù)據(jù)集、模型和算法以及高模擬速度的簡單基本模塊,旨在讓研究員能夠更快速、更容易地開發(fā)和評估聯(lián)邦算法。
FedJAX
https://github.com/google/fedjax
JAX
https://github.com/google/jax
在這篇文章中,我們將討論 FedJAX 的庫結構和內容。我們會證明,在 TPU 中,F(xiàn)edJAX 可通過 EMNIST 數(shù)據(jù)集的聯(lián)合平均,在幾分鐘內就能訓練完模型。而通過 Stack Overflow 數(shù)據(jù)集的標準超參數(shù) (Hyperparameter),則需要將近 1 小時。
EMNIST
https://github.com/google/fedjax/blob/main/fedjax/datasets/emnist.py
聯(lián)合平均
https://fedjax.readthedocs.io/en/latest/fedjax.algorithms.html#module-fedjax.algorithms.fed_avg
Stack Overflow
https://github.com/google/fedjax/blob/main/fedjax/datasets/stackoverflow.py
庫結構
FedJAX 注重易用性,因此僅引進了少量新概念。使用 FedJAX 編寫的代碼與學術論文用于描述新穎算法的偽代碼類似,因此極易上手。除此之外,雖然 FedJAX 提供了聯(lián)邦學習的基本模塊,但用戶可以將其替換為最基本的實現(xiàn)(僅使用 NumPy 和 JAX),并且仍然可以將整體訓練速度保持在一個合理的區(qū)間。
與學術論文用于描述新穎算法的偽代碼類似
https://github.com/google/fedjax/blob/main/README.md#quickstart
NumPy
https://numpy.org/
包含的數(shù)據(jù)集和模型
在當前聯(lián)邦學習研究領域,存在各種各樣常用的數(shù)據(jù)集和模型,例如圖像識別 (Image recognition)、語言建模 (Language modeling) 等。越來越多這樣的數(shù)據(jù)集和模型無需安裝即可直接用于 FedJAX,因此用戶無需從頭開始編寫預處理數(shù)據(jù)集和模型。這不僅有利于對不同的聯(lián)邦算法進行有效比較,還加速了新算法的開發(fā)。
目前,F(xiàn)edJAX 與以下數(shù)據(jù)集和示例模型一起打包:
EMNIST-62,一項字符識別任務
https://github.com/google/fedjax/blob/main/fedjax/datasets/emnist.py
Shakespeare,一項下一字符預測任務
https://github.com/google/fedjax/blob/main/fedjax/datasets/shakespeare.py
Stack Overflow,一項下一字詞預測任務
https://github.com/google/fedjax/blob/main/fedjax/datasets/stackoverflow.py
除了以上標準設置,F(xiàn)edJAX 還提供用于創(chuàng)建新數(shù)據(jù)集和模型的工具,這些新數(shù)據(jù)集和模型可以與庫的其余內容共同使用。
工具
https://fedjax.readthedocs.io/en/latest/fedjax.html#federated-data
此外,F(xiàn)edJAX 支持聯(lián)合平均的標準實現(xiàn),也支持用于在分散式示例上訓練共享模型的其他聯(lián)邦算法,例如自適應聯(lián)邦優(yōu)化器、不可知聯(lián)合平均以及 Mime,從而讓比較和評估現(xiàn)有算法變得更加簡單。
自適應聯(lián)邦優(yōu)化器
https://fedjax.readthedocs.io/en/latest/fedjax.algorithms.html#module-fedjax.algorithms.fed_avg
不可知聯(lián)合平均
https://fedjax.readthedocs.io/en/latest/fedjax.algorithms.html#module-fedjax.algorithms.agnostic_fed_avg
Mime
https://fedjax.readthedocs.io/en/latest/fedjax.algorithms.html#module-fedjax.algorithms.mime
性能評估
我們在兩項任務上對自適應聯(lián)合平均的標準 FedJAX 實現(xiàn)進行了基準測試:圖像識別任務(測試聯(lián)邦 EMNIST-62 數(shù)據(jù)集)和下一字詞預測任務(測試 Stack Overflow 數(shù)據(jù)集)。聯(lián)邦 EMNIST-62 數(shù)據(jù)集較小,由 3400 名用戶和他們創(chuàng)建的示例(共 62 個拉丁字母數(shù)字字符)構成;而 Stack Overflow 數(shù)據(jù)集較大,由數(shù)百萬問題和答案構成(這些問題和答案來自于擁有成千上萬名用戶的 Stack Overflow 論壇)。
自適應聯(lián)合平均
https://openreview.net/pdf?id=LkFG3lB13U5
聯(lián)邦 EMNIST-62 數(shù)據(jù)集
https://github.com/google/fedjax/blob/main/fedjax/datasets/emnist.py
Stack Overflow 數(shù)據(jù)集
https://github.com/google/fedjax/blob/main/fedjax/datasets/stackoverflow.py
我們在專門用于機器學習的各種硬件上測量性能。對于聯(lián)邦 EMNIST-62,我們在 GPU (NVIDIA V100) 和 TPU(Google TPU v2 上的 1 個 TensorCore)加速器上對單一模型進行了 1500 輪訓練(每輪 10 個客戶端)。
對于 Stack Overflow,我們在 GPU (NVIDIA V100)、單核 TPU(Google TPU v2 上 1 個 TensorCore)及多核 TPU(Google TPU v2 上 8 個 TensorCore)上對單一模型進行了 1500 輪訓練(每輪 50 個客戶端)。其中,在 GPU 上使用 jax.jit,在單核 TPU 上僅使用 jax.jit,而在多核 TPU 上使用 jax.pmap。在下方圖表中,我們記錄了每輪訓練的平均完成時間、完整評估測試數(shù)據(jù)所需時間以及整體執(zhí)行時間(整體執(zhí)行包含訓練和完整評估)。
通過標準超參數(shù)和 TPUs,聯(lián)邦 EMNIST-62 的整個實驗可以在幾分鐘之內完成,而 Stack Overflow.的實驗需要 1 小時左右的時間。
我們還評估了隨著每輪客戶端數(shù)量增加之后的 Stack Overflow 平均每輪訓練時長。通過比較圖表上 8 核 TPU 與單核 TPU 的平均每輪訓練時長,我們很容易就能發(fā)現(xiàn),如果每輪參與的客戶端數(shù)量較多,則使用多核 TPU 能極大縮短運行時間(對微分化的不公開學習等應用來說非常有幫助)。
微分化的不公開學習
https://openreview.net/forum?id=BJ0hF1Z0b
結論和未來研究方向
在這篇文章中,我們介紹了 FedJAX 這種適用于研究、速度較快且簡單易用的聯(lián)邦學習模擬庫。我們希望 FedJAX 能推動聯(lián)邦學習的深入研究,同時引起人們對于該領域的更多關注。未來,我們計劃繼續(xù)發(fā)展現(xiàn)有算法集、聚合機制、數(shù)據(jù)集和模型。
歡迎各位隨時查閱我們的教程筆記本,或者親自體驗 FedJAX!
教程筆記本
https://fedjax.readthedocs.io/en/latest/
親自體驗 FedJAX
https://github.com/google/fedjax/blob/main/examples
若想進一步了解 FedJAX 及其與 Tensorflow Federated 等平臺的關系,請參閱我們的論文、README 或常見問題解答。
責任編輯:haq
-
服務器
+關注
關注
12文章
9265瀏覽量
85787 -
機器學習
+關注
關注
66文章
8430瀏覽量
132858
原文標題:FedJAX:使用 JAX 進行聯(lián)邦學習模擬
文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。
發(fā)布評論請先 登錄
相關推薦
評論