Keras和PyTorch變得極為流行,主要原因是它們比TensorFlow更容易使用。本文對比了Keras和PyTorch四個方面的不同,讀者可以針對自己的任務(wù)來選擇。
對于許多科學(xué)家、工程師和開發(fā)人員來說,TensorFlow是他們的第一個深度學(xué)習(xí)框架。但indus.ai公司機(jī)器學(xué)習(xí)工程師George Seif認(rèn)為,TF并不是非常的用戶友好。
相比TF,Seif認(rèn)為Keras和PyTorch比TensorFlow更易用,已經(jīng)獲得了巨大的普及。
Keras本身不是框架,而是一個位于其他Deep Learning框架之上的高級API。目前它支持TensorFlow,Theano和CNTK。Keras是迄今為止啟動和運行最快最簡單的框架。定義神經(jīng)網(wǎng)絡(luò)是直觀的,使用功能性API允許人們將層定義為函數(shù)。
而PyTorch像Keras一樣,它也抽象了深度網(wǎng)絡(luò)編程的大部分混亂部分。PyTorch介于Keras和TensorFlow之間,比Keras擁有更靈活、更好的控制力,與此同時用戶又不必做任何瘋狂的聲明式編程。
深度學(xué)習(xí)練習(xí)者整天都在爭論應(yīng)該使用哪個框架。接下來我們將通過4個不同方面,來對比Keras和PyTorch,最終初學(xué)者會明白應(yīng)該選誰。
用于定義模型的類與函數(shù)
Keras提供功能性API來定義深度學(xué)習(xí)模型。神經(jīng)網(wǎng)絡(luò)被定義為一組順序函數(shù),功能定義層1的輸出是功能定義層2的輸入,例如下面demo代碼:
img_input = layers.Input(shape=input_shape)x = layers.Conv2D(64, (3, 3), activation='relu')(img_input)x = layers.Conv2D(64, (3, 3), activation='relu')(x)x = layers.MaxPooling2D((2, 2), strides=(2, 2))(x)
而PyTorch將網(wǎng)絡(luò)設(shè)置為一個類,擴(kuò)展了Torch庫中的torch.nn.Module,PyTorch允許用戶訪問所有Python的類功能而不是簡單的函數(shù)調(diào)用。與Keras類似,PyTorch提供了層作為構(gòu)建塊,但由于它們位于Python類中,因此它們在類的__init __()方法中引用,并由類的forward()方法執(zhí)行。例如下面demo代碼:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, 3) self.conv2 = nn.Conv2d(64, 64, 3) self.pool = nn.MaxPool2d(2, 2) def forward(self, x): x = F.relu(self.conv1(x)) x = self.pool(F.relu(self.conv2(x))) return xmodel = Net()
所以如果你想更清晰、更優(yōu)雅地定義網(wǎng)絡(luò),可以選擇PyTorch;如果只是求快好上手,可以選擇Keras。
張量、計算圖與標(biāo)準(zhǔn)陣列
Keras API隱藏了編碼器的許多混亂細(xì)節(jié)。定義網(wǎng)絡(luò)層非常直觀,默認(rèn)設(shè)置已經(jīng)足以應(yīng)付大部分情況,不需要涉及到非常底層的內(nèi)容。
而當(dāng)你真正觸達(dá)到更底層的TensorFlow代碼時,同時你也獲得了隨之而來的最具有挑戰(zhàn)性的部分:你需要確保所有矩陣乘法都排成一行。哦對了,甚至別指望打印出圖層的一個輸出,因為你只會在終端上打印出一個漂亮的Tensor定義。
相比起來,PyTorch在這些方面就做的更讓人欣慰一些。你需要知道每個層的輸入和輸出大小,但這很快就能掌握。同時你也不必處理構(gòu)建一個無法在調(diào)試中看到的抽象計算圖。
PyTorch的另一個優(yōu)勢是可以在Torch Tensors和Numpy陣列之間來回切換。而反觀TF,如果需要實現(xiàn)自定義的東西,在TF張量和Numpy陣列之間來回轉(zhuǎn)換可能會很麻煩,需要開發(fā)人員對TensorFlow會話有充分的了解。
PyTorch上這種操作實際上要簡單得多。你只需要知道兩個操作:一個將Torch Tensor(一個Variable對象)切換到Numpy,另一個反過來。
當(dāng)然,如果不需要實現(xiàn)任何花哨的東西,那么Keras會做得很好,因為你不會遇到任何TensorFlow路障。
訓(xùn)練模型
在Keras上訓(xùn)練模型非常容易!一個簡單的.fit()走四方。下面是demo代碼:
history = model.fit_generator( generator=train_generator, epochs=10, validation_data=validation_generator)
但在PyTorch中訓(xùn)練模型就費點事了,包括幾個步驟:
在每批訓(xùn)練開始時初始化梯度
運行正向傳遞模式
運行向后傳遞
計算損失并更新權(quán)重
for epoch in range(2): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): # Get the inputs; data is a list of [inputs, labels] inputs, labels = data # (1) Initialise gradients optimizer.zero_grad() # (2) Forward pass outputs = net(inputs) loss = criterion(outputs, labels) # (3) Backward loss.backward() # (4) Compute the loss and update the weights optimizer.step()
你看看,就運行個訓(xùn)練就得這么多步驟!
我想這樣你總能意識到發(fā)生了什么。同時,由于這些模型訓(xùn)練步驟在訓(xùn)練不同模型時基本保持不變,因此非常不必要。
如果安裝了tensorflow-gpu,默認(rèn)情況下在Keras中啟用并完成使用GPU。然后,如果希望將某些操作移動到CPU,則可以使用單行操作。
with tf.device('/cpu:0'): y = apply_non_max_suppression(x)
在PyTorch就得費點勁,你必須為每個Torch張量和numpy變量明確啟用GPU。如果在CPU和GPU之間來回切換以進(jìn)行不同的操作,就會使代碼變得混亂并且容易出錯。
例如,要將我們以前的模型轉(zhuǎn)移到GPU上運行,我們必須執(zhí)行以下操作:
# Get the GPU devicedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# Transfer the network to GPUnet.to(device)# Transfer the inputs and labels to GPUinputs, labels = data[0].to(device), data[1].to(device)
在GPU這塊,Keras憑借其簡潔和漂亮的默認(rèn)設(shè)置贏得了勝利。
選擇框架的建議
Seif通常給出的建議是從Keras開始,畢竟又快、又簡單、又好用!你甚至可以執(zhí)行自定義圖層和損失函數(shù)的操作,而無需觸及任何一行TensorFlow。
但如果你確實開始深入了解深層網(wǎng)絡(luò)中更細(xì)粒度的方面,或者正在實現(xiàn)非標(biāo)準(zhǔn)的東西,那么PyTorch就是首選庫。
-
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5554瀏覽量
122482 -
keras
+關(guān)注
關(guān)注
2文章
20瀏覽量
6169 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13766
原文標(biāo)題:深度學(xué)習(xí)框架如何選?4大場景對比Keras和PyTorch
文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
esd,mcu和adc復(fù)位問題的必須要注意的四個方面!
印制電路板設(shè)計四個方面的要求
TCO在CRT方面的對比
TCO在LCD方面的對比
TensorFlow、MXNet、CNTK、Theano四個框架對比分析

總結(jié)了區(qū)塊鏈技術(shù)的四個方面來了解區(qū)塊鏈
2018年智能鎖行業(yè)的問題大致總結(jié)為以下四個方面
高頻PCB設(shè)計中,工程師需考慮四個方面帶來的干擾問題并給解決方案
從四個方面解讀PCB射頻電路基礎(chǔ)特性及重要因素
無錫市集成電路產(chǎn)業(yè)四個方面的特點
四個方面區(qū)分MPK和CBB電容資料下載

四個方面看SoC 設(shè)計資料下載

物聯(lián)網(wǎng)技術(shù)在四個方面的應(yīng)用趨勢分析
示波器經(jīng)常說“四個部分”是哪四個部分?

評論