在我們之前發布的文章《一個新的 TensorFlow Lite 示例應用:棋盤游戲》中,展示了如何使用 TensorFlow 和 TensorFlow Agents 來訓練強化學習 (RL) agent,使其玩一個簡單的棋盤游戲“Plane Strike”。我們還將訓練后的模型轉換為 TensorFlow Lite,然后將其部署到功能完備的 Android 應用中。本文,我們將演示一種全新路徑:使用 Flax/JAX 訓練相同的強化學習 agent,然后將其部署到我們之前構建的同一款 Android 應用中。
簡單回顧一下游戲規則:我們基于強化學習的 agent 需要根據真人玩家的棋盤位置預測擊打位置,以便能早于真人玩家完成游戲。如需進一步了解游戲規則,請參閱我們之前發布的文章。
“Plane Strike”游戲演示
背景:JAX 和 TensorFlow
JAX 是一個與 NumPy 類似的內容庫,由 Google Research 部門專為實現高性能計算而開發。JAX 使用 XLA 針對 GPU 和 TPU 優化的程序進行編譯。
JAX
https://github.com/google/jax
XLA
https://tensorflow.google.cn/xla
TPU
https://cloud.google.com/tpu
而 Flax 則是在 JAX 基礎上構建的一款熱門神經網絡庫。研究人員一直在使用 JAX/Flax 來訓練包含數億萬個參數的超大模型(如用于語言理解和生成的 PaLM,或者用于圖像生成的 Imagen),以便充分利用現代硬件。
Flax
https://github.com/google/flax
PaLM
https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html
Imagen
https://imagen.research.google/
如果您不熟悉 JAX 和 Flax,可以先從 JAX 101 教程和 Flax 入門示例開始。
JAX 101 教程
https://jax.readthedocs.io/en/latest/jax-101/index.html
Flax 入門示例
https://flax.readthedocs.io/en/latest/getting_started.html
2015 年底,TensorFlow 作為 Machine Learning (ML) 內容庫問世,現已發展為一個豐富的生態系統,其中包含用于實現 ML 流水線生產化 (TFX)、數據可視化 (TensorBoard),和將 ML 模型部署到邊緣設備 (TensorFlow Lite) 的工具,以及在網絡瀏覽器上運行的裝置,或能夠執行 JavaScript (TensorFlow.js) 的任何裝置。
TFX
https://tensorflow.google.cn/tfx
TensorBoard
https://tensorboard.dev/
TensorFlow Lite
https://tensorflow.google.cn/lite
TensorFlow.js
https://tensorflow.google.cn/js
在 JAX 或 Flax 中開發的模型也可以利用這一豐富的生態系統。方法是首先將此類模型轉換為 TensorFlow SavedModel 格式,然后使用與它們在 TensorFlow 中原生開發相同的工具。
SavedModel
https://tensorflow.google.cn/guide/saved_model
如果您已經擁有經 JAX 訓練的模型并希望立即進行部署,我們整合了一份資源列表供您參考:
視頻 “使用 TensorFlow Serving 為 JAX 模型提供服務”,展示了如何使用 TensorFlow Serving 部署 JAX 模型。
https://youtu.be/I4dx7OI9FJQ?t=36
文章《借助 TensorFlow.js 在網絡上使用 JAX》,對如何將 JAX 模型轉換為 TFJS,并在網絡應用中運行進行了詳細講解。
https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
本篇文章演示了如何將 Flax/JAX 模型轉換為 TFLite,并在原生 Android 應用中運行該模型。
總而言之,無論您的部署目標是服務器、網絡還是移動設備,我們都會為您提供相應的幫助。
使用 Flax/JAX 實現游戲 agent
將目光轉回到棋盤游戲。為了實現強化學習 agent,我們將會利用與之前相同的 OpenAI gym 環境。這次,我們將使用 Flax/JAX 訓練相同的策略梯度模型。回想一下,在數學層面上策略梯度的定義是:
OpenAI gym
https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs
其中:
T:每段的時步數,各段的時步數可能有所不同
st:時步上的狀態 t
at:時步上的所選操作 t 指定狀態s
πθ:參數為 θ 的策略
R(*):在指定策略下,收集到的獎勵
我們定義了一個 3 層 MLP 作為策略網絡,該網絡可以預測 agent 的下一個擊打位置。
class PolicyGradient(nn.Module): """Neural network to predict the next strike position.""" @nn.compact def __call__(self, x): dtype = jnp.float32 x = x.reshape((x.shape[0], -1)) x = nn.Dense( features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)( x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x) policy_probabilities = nn.softmax(x) return policy_probabilities
在我們訓練循環的每次迭代中,我們都會使用神經網絡玩一局游戲、收集軌跡信息(游戲棋盤位置、采取的操作和獎勵)、對獎勵進行折扣,然后使用相應軌跡訓練模型。
for i in tqdm(range(iterations)): predict_fn = functools.partial(run_inference, params) board_log, action_log, result_log = common.play_game(predict_fn) rewards = common.compute_rewards(result_log) optimizer, params, opt_state = train_step(optimizer, params, opt_state, board_log, action_log, rewards)
在 train_step() 方法中,我們首先會使用軌跡計算損失,然后使用 jax.grad() 計算梯度,最后,使用 Optax(用于 JAX 的梯度處理和優化庫)來更新模型參數。
Optax
https://github.com/deepmind/optax
def compute_loss(logits, labels, rewards): one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2) loss = -jnp.mean( jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards)) return loss def train_step(model_optimizer, params, opt_state, game_board_log, predicted_action_log, action_result_log): """Run one training step.""" def loss_fn(model_params): logits = run_inference(model_params, game_board_log) loss = compute_loss(logits, predicted_action_log, action_result_log) return loss def compute_grads(params): return jax.grad(loss_fn)(params) grads = compute_grads(params) updates, opt_state = model_optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) return model_optimizer, params, opt_state @jax.jit def run_inference(model_params, board): logits = PolicyGradient().apply({'params': model_params}, board) return logits
這就是訓練循環。如下圖所示,我們可以在 TensorBoard 中觀察訓練進度;其中,我們使代理指標“game_length”(完成游戲所需的步驟數)來跟蹤進度:若 agent 變得更聰明,它便能以更少的步驟完成游戲。
將 Flax/JAX 模型轉換為
TensorFlow Lite 并與
Android 應用集成
完成模型訓練后,我們使用 jax2tf(一款 TensorFlow-JAX 互操作工具),將 JAX 模型轉換為 TensorFlow concrete function。最后一步是調用 TensorFlow Lite 轉換器來將 concrete function 轉換為 TFLite 模型。
jax2tf
https://github.com/google/jax/tree/main/jax/experimental/jax2tf
# Convert to tflite model model = PolicyGradient() jax_predict_fn = lambda input: model.apply({'params': params}, input) tf_predict = tf.function( jax2tf.convert(jax_predict_fn, enable_xla=False), input_signature=[ tf.TensorSpec( shape=[1, common.BOARD_SIZE, common.BOARD_SIZE], dtype=tf.float32, name='input') ], autograph=False, ) converter = tf.lite.TFLiteConverter.from_concrete_functions( [tf_predict.get_concrete_function()], tf_predict) tflite_model = converter.convert() # Save the model with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f: f.write(tflite_model)
經 JAX 轉換的 TFLite 模型與任何經 TensorFlow 訓練的 TFLite 模型會有完全一致的行為。您可以使用 Netron 進行可視化:
使用 Netron 對 Flax/JAX 轉換的 TFLite 模型進行可視化
我們可以使用與之前完全一樣的 Java 代碼來調用模型并獲取預測結果。
convertBoardStateToByteBuffer(board); tflite.run(boardData, outputProbArrays); float[] probArray = outputProbArrays[0]; int agentStrikePosition = -1; float maxProb = 0; for (int i = 0; i < probArray.length; i++) { int x = i / Constants.BOARD_SIZE; int y = i % Constants.BOARD_SIZE; if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) { agentStrikePosition = i; maxProb = probArray[i]; } }
總結
本文詳細介紹了如何使用 Flax/JAX 訓練簡單的強化學習模型、利用 jax2tf 將其轉換為 TensorFlow Lite,以及將轉換后的模型集成到 Android 應用。
現在,您已經了解了如何使用 Flax/JAX 構建神經網絡模型,以及如何利用強大的 TensorFlow 生態系統,在幾乎任何您想要的位置部署模型。我們十分期待看到您使用 JAX 和 TensorFlow 構建出色應用!
審核編輯:劉清
-
神經網絡
+關注
關注
42文章
4771瀏覽量
100715 -
TPU
+關注
關注
0文章
140瀏覽量
20720 -
MLP
+關注
關注
0文章
57瀏覽量
4241
原文標題:使用 JAX 構建強化學習 agent,并借助 TensorFlow Lite 將其部署到 Android 應用中
文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論