色哟哟视频在线观看-色哟哟视频在线-色哟哟欧美15最新在线-色哟哟免费在线观看-国产l精品国产亚洲区在线观看-国产l精品国产亚洲区久久

0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

如何將Flax/JAX模型轉換為TFLite并在原生Android應用中運行呢

Tensorflowers ? 來源:TensorFlow ? 作者:TensorFlow ? 2022-11-02 10:13 ? 次閱讀

在我們之前發布的文章《一個新的 TensorFlow Lite 示例應用:棋盤游戲》中,展示了如何使用 TensorFlow 和 TensorFlow Agents 來訓練強化學習 (RL) agent,使其玩一個簡單的棋盤游戲“Plane Strike”。我們還將訓練后的模型轉換為 TensorFlow Lite,然后將其部署到功能完備的 Android 應用中。本文,我們將演示一種全新路徑:使用 Flax/JAX 訓練相同的強化學習 agent,然后將其部署到我們之前構建的同一款 Android 應用中。

簡單回顧一下游戲規則:我們基于強化學習的 agent 需要根據真人玩家的棋盤位置預測擊打位置,以便能早于真人玩家完成游戲。如需進一步了解游戲規則,請參閱我們之前發布的文章。

23754442-59d4-11ed-a3b6-dac502259ad0.gif

“Plane Strike”游戲演示

背景:JAX 和 TensorFlow

JAX 是一個與 NumPy 類似的內容庫,由 Google Research 部門專為實現高性能計算而開發。JAX 使用 XLA 針對 GPU 和 TPU 優化的程序進行編譯。

JAX

https://github.com/google/jax

XLA

https://tensorflow.google.cn/xla

TPU

https://cloud.google.com/tpu

而 Flax 則是在 JAX 基礎上構建的一款熱門神經網絡庫。研究人員一直在使用 JAX/Flax 來訓練包含數億萬個參數的超大模型(如用于語言理解和生成的 PaLM,或者用于圖像生成的 Imagen),以便充分利用現代硬件

Flax

https://github.com/google/flax

PaLM

https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html

Imagen

https://imagen.research.google/

如果您不熟悉 JAX 和 Flax,可以先從 JAX 101 教程和 Flax 入門示例開始。

JAX 101 教程

https://jax.readthedocs.io/en/latest/jax-101/index.html

Flax 入門示例

https://flax.readthedocs.io/en/latest/getting_started.html

2015 年底,TensorFlow 作為 Machine Learning (ML) 內容庫問世,現已發展為一個豐富的生態系統,其中包含用于實現 ML 流水線生產化 (TFX)、數據可視化 (TensorBoard),和將 ML 模型部署到邊緣設備 (TensorFlow Lite) 的工具,以及在網絡瀏覽器上運行的裝置,或能夠執行 JavaScript (TensorFlow.js) 的任何裝置。

TFX

https://tensorflow.google.cn/tfx

TensorBoard

https://tensorboard.dev/

TensorFlow Lite

https://tensorflow.google.cn/lite

TensorFlow.js

https://tensorflow.google.cn/js

在 JAX 或 Flax 中開發的模型也可以利用這一豐富的生態系統。方法是首先將此類模型轉換為 TensorFlow SavedModel 格式,然后使用與它們在 TensorFlow 中原生開發相同的工具。

SavedModel

https://tensorflow.google.cn/guide/saved_model

如果您已經擁有經 JAX 訓練的模型并希望立即進行部署,我們整合了一份資源列表供您參考:

視頻 “使用 TensorFlow Serving 為 JAX 模型提供服務”,展示了如何使用 TensorFlow Serving 部署 JAX 模型。

https://youtu.be/I4dx7OI9FJQ?t=36

文章《借助 TensorFlow.js 在網絡上使用 JAX》,對如何將 JAX 模型轉換為 TFJS,并在網絡應用中運行進行了詳細講解。

https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html

本篇文章演示了如何將 Flax/JAX 模型轉換為 TFLite,并在原生 Android 應用中運行該模型。

總而言之,無論您的部署目標是服務器、網絡還是移動設備,我們都會為您提供相應的幫助。

使用 Flax/JAX 實現游戲 agent

將目光轉回到棋盤游戲。為了實現強化學習 agent,我們將會利用與之前相同的 OpenAI gym 環境。這次,我們將使用 Flax/JAX 訓練相同的策略梯度模型。回想一下,在數學層面上策略梯度的定義是:

OpenAI gym

https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs

23e88678-59d4-11ed-a3b6-dac502259ad0.png

其中:

T:每段的時步數,各段的時步數可能有所不同

st:時步上的狀態 t

at:時步上的所選操作 t 指定狀態s

πθ:參數為 θ 的策略

R(*):在指定策略下,收集到的獎勵

我們定義了一個 3 層 MLP 作為策略網絡,該網絡可以預測 agent 的下一個擊打位置。

class PolicyGradient(nn.Module):
  """Neural network to predict the next strike position."""


@nn.compact
  def __call__(self, x):
    dtype = jnp.float32
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(
        features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(
           x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
    policy_probabilities = nn.softmax(x)
    return policy_probabilities

在我們訓練循環的每次迭代中,我們都會使用神經網絡玩一局游戲、收集軌跡信息(游戲棋盤位置、采取的操作和獎勵)、對獎勵進行折扣,然后使用相應軌跡訓練模型。

for i in tqdm(range(iterations)):
   predict_fn = functools.partial(run_inference, params)
   board_log, action_log, result_log = common.play_game(predict_fn)
   rewards = common.compute_rewards(result_log)
   optimizer, params, opt_state = train_step(optimizer, params, opt_state,
                                             board_log, action_log, rewards)

在 train_step() 方法中,我們首先會使用軌跡計算損失,然后使用 jax.grad() 計算梯度,最后,使用 Optax(用于 JAX 的梯度處理和優化庫)來更新模型參數。

Optax

https://github.com/deepmind/optax

def compute_loss(logits, labels, rewards):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
  loss = -jnp.mean(
      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
  return loss


def train_step(model_optimizer, params, opt_state, game_board_log,
              predicted_action_log, action_result_log):
"""Run one training step."""

  def loss_fn(model_params):
    logits = run_inference(model_params, game_board_log)
    loss = compute_loss(logits, predicted_action_log, action_result_log)
    return loss

  def compute_grads(params):
    return jax.grad(loss_fn)(params)

  grads = compute_grads(params)
  updates, opt_state = model_optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return model_optimizer, params, opt_state


@jax.jit
def run_inference(model_params, board):
  logits = PolicyGradient().apply({'params': model_params}, board)
  return logits

這就是訓練循環。如下圖所示,我們可以在 TensorBoard 中觀察訓練進度;其中,我們使代理指標“game_length”(完成游戲所需的步驟數)來跟蹤進度:若 agent 變得更聰明,它便能以更少的步驟完成游戲。

23f8d758-59d4-11ed-a3b6-dac502259ad0.png

將 Flax/JAX 模型轉換為

TensorFlow Lite 并與

Android 應用集成

完成模型訓練后,我們使用 jax2tf(一款 TensorFlow-JAX 互操作工具),將 JAX 模型轉換為 TensorFlow concrete function。最后一步是調用 TensorFlow Lite 轉換器來將 concrete function 轉換為 TFLite 模型。

jax2tf

https://github.com/google/jax/tree/main/jax/experimental/jax2tf

# Convert to tflite model
 model = PolicyGradient()
 jax_predict_fn = lambda input: model.apply({'params': params}, input)


 tf_predict = tf.function(
     jax2tf.convert(jax_predict_fn, enable_xla=False),
     input_signature=[
         tf.TensorSpec(
             shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
             dtype=tf.float32,
             name='input')
     ],
     autograph=False,
 )


 converter = tf.lite.TFLiteConverter.from_concrete_functions(
     [tf_predict.get_concrete_function()], tf_predict)


 tflite_model = converter.convert()


 # Save the model
 with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
   f.write(tflite_model)

經 JAX 轉換的 TFLite 模型與任何經 TensorFlow 訓練的 TFLite 模型會有完全一致的行為。您可以使用 Netron 進行可視化:

242392fe-59d4-11ed-a3b6-dac502259ad0.png

使用 Netron 對 Flax/JAX 轉換的 TFLite 模型進行可視化

我們可以使用與之前完全一樣的 Java 代碼來調用模型并獲取預測結果。

convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}

總結

本文詳細介紹了如何使用 Flax/JAX 訓練簡單的強化學習模型、利用 jax2tf 將其轉換為 TensorFlow Lite,以及將轉換后的模型集成到 Android 應用。

現在,您已經了解了如何使用 Flax/JAX 構建神經網絡模型,以及如何利用強大的 TensorFlow 生態系統,在幾乎任何您想要的位置部署模型。我們十分期待看到您使用 JAX 和 TensorFlow 構建出色應用!





審核編輯:劉清

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • 神經網絡
    +關注

    關注

    42

    文章

    4771

    瀏覽量

    100715
  • TPU
    TPU
    +關注

    關注

    0

    文章

    140

    瀏覽量

    20720
  • MLP
    MLP
    +關注

    關注

    0

    文章

    57

    瀏覽量

    4241

原文標題:使用 JAX 構建強化學習 agent,并借助 TensorFlow Lite 將其部署到 Android 應用中

文章出處:【微信號:tensorflowers,微信公眾號:Tensorflowers】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    使用電腦上tensorflow創建的模型轉換為tflite格式了,導入后進度條反復出現0-100%變化,為什么?

    使用電腦上tensorflow創建的模型轉換為tflite格式了,導入后,進度條反復出現0-100%變化,卡了一個晚上了還沒分析好?
    發表于 03-19 06:20

    如何將采樣位移轉換為采樣速度

    我是新手,在Labview編程如何將采樣的位移轉換為速度?求圖解,謝謝
    發表于 04-25 14:56

    如何將秒數轉換為時間字符串?

    請問如何將數值型秒數轉換為時間字符串?比如3600s轉換為01:00:00
    發表于 03-30 13:15

    如何將傳統ANN轉換為SNN?

    SNN和ANN的區別是什么?如何將傳統ANN轉換為SNN?
    發表于 09-28 06:15

    如何將觸控芯片的IIC接口轉換為USB接口

    CH554是什么?CH554如何實現數據轉換如何將觸控芯片的IIC接口轉換為USB接口
    發表于 02-24 07:54

    EIQ onnx模型轉換為tf-lite失敗怎么解決?

    問題: 而我們需要您幫助我們回答這些問題:a) Dose eIQ(版本 2.7.12)支持 onnx 模型轉換為 tflte 格式?(文件見附件)b) 找不到float16 的量化選項,你知道
    發表于 03-31 08:03

    如何在MIMXRT1064評估套件上部署tflite模型

    我有一個嬰兒哭聲檢測 tflite (tensorflow lite) 文件,其中包含模型本身。我如何將模型部署到 MIMXRT1064-evk 以通過 MCUXpresso IDE
    發表于 04-06 06:24

    如何將DS_CNN_S.pb轉換為ds_cnn_s.tflite

    用于圖像分類(eIQ tensflowlite 庫)。從廣義上講,我正在尋找該腳本,您可能已經使用該腳本 DS_CNN_S.pb 轉換為 ds_cnn_s.tflite我能夠查看兩個模型
    發表于 04-19 06:11

    Pytorch模型轉換為DeepViewRT模型時出錯怎么解決?

    我最終可以在 i.MX 8M Plus 處理器上部署 .rtm 模型。 我遵循了 本指南,我 Pytorch 模型轉換為 ONNX 模型
    發表于 06-09 06:42

    如何將Detectron2和Layout-LM模型轉換為OpenVINO中間表示(IR)和使用CPU插件進行推斷?

    無法確定如何將 Detectron2* 和 Layout-LM* 模型轉換為OpenVINO?中間表示 (IR) 和使用 CPU 插件進行推斷。
    發表于 08-15 06:23

    數學原理:如何將ADC代碼轉換為電壓(第1篇)

    許多初步了解模數轉換器(ADC)的人想知道如何將ADC代碼轉換為電壓。或者,他們的問題是針對特定應用,例如:如何將ADC代碼轉換回物理量,如
    發表于 04-18 03:30 ?4044次閱讀

    如何將Altera的SDC約束轉換為Xilinx XDC約束

    了解如何將Altera的SDC約束轉換為Xilinx XDC約束,以及需要更改或修改哪些約束以使Altera的約束適用于Vivado設計軟件。
    的頭像 發表于 11-27 07:17 ?5123次閱讀

    Android中使用TFLite c++部署

    之前的文章,我們跟大家介紹過如何使用NNAPI來加速TFLite-Android的inference(可參考使用NNAPI加速android-tflite的Mobilenet分類器...
    發表于 02-07 11:57 ?7次下載
    在<b class='flag-5'>Android</b>中使用<b class='flag-5'>TFLite</b> c++部署

    如何將簡單的汽車轉換為無線遙控汽車

    電子發燒友網站提供《如何將簡單的汽車轉換為無線遙控汽車.zip》資料免費下載
    發表于 10-21 14:51 ?2次下載
    <b class='flag-5'>如何將</b>簡單的汽車<b class='flag-5'>轉換為</b>無線遙控汽車

    如何將Android代碼轉換成JS代碼運行

    Autojs這個工具,因為它本身是使用的Rhino引擎開發的,因此它可以把Android代碼轉換成JavaScript語法的代碼來運行,Autojs提供了幾個相關的方法來輔助
    的頭像 發表于 03-03 14:05 ?2620次閱讀
    主站蜘蛛池模板: 国产精品悠悠久久人妻精品| 国产美女一区二区| 内地同志男16china16| 乳色吐息未增删樱花ED在线观看| 我年轻漂亮的继坶2中字在线播放 我们中文在线观看免费完整版 | 免费国产在线观看| 日韩少妇爆乳无码专区| 18禁黄久久久AAA片| 精品国产三级a| 亚洲欧美自拍明星换脸| 国产人妻麻豆蜜桃色| 甜性涩爱下载| 丰满的美女射精动态图| 欧美日韩视频一区二区三区| 97精品国产亚洲AV高清| 可以看的黄页的网站| 中文字幕欧美一区| 精品亚洲一区二区三区在线播放| 亚洲国产成人久久一区www妖精| 国产成人精品综合在线| 天天操狠狠操夜夜操| 国产跪地吃黄金喝圣水合集| 爽爽影院免费观看| 国产色精品久久人妻无码看片软件 | no视频在线观看| 青青视频国产色偷偷| 不良网站进入窗口软件下载免费 | 亚洲国产精品第一影院在线观看 | 欧美写真视频一区| 粗好大用力好深快点漫画| 色欲av蜜臀av高清| 国产一区内射最近更新| 亚洲欧美成人| 美女张开让男生桶| sao虎影院桃红视频在线观看| 日韩亚洲欧洲在线rrrr片| 国产精品自在自线亚洲| 影视先锋男人无码在线| 欧美高跟镣铐bdsm视频| 国产精品96久久久久久AV网址| 亚洲 视频 在线 国产 精品 |