PyTorch 是一個流行的開源機器學習庫,廣泛用于計算機視覺和自然語言處理等領域。它提供了強大的計算圖功能和動態圖特性,使得模型的構建和調試變得更加靈活和直觀。
數據準備
在訓練模型之前,首先需要準備好數據集。PyTorch 提供了 torch.utils.data.Dataset
和 torch.utils.data.DataLoader
兩個類來幫助我們加載和批量處理數據。
1. 定義 Dataset
Dataset
類需要我們實現 __init__
、__len__
和 __getitem__
三個方法。__init__
方法用于初始化數據集,__len__
返回數據集中的樣本數量,__getitem__
根據索引返回單個樣本。
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index]
label = self.labels[index]
return data, label
2. 使用 DataLoader
DataLoader
類用于封裝數據集,并提供批量加載、打亂數據和多線程加載等功能。
from torch.utils.data import DataLoader
dataset = CustomDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
模型定義
在 PyTorch 中,模型是通過繼承 torch.nn.Module
類來定義的。我們需要實現 __init__
方法來定義網絡層,并實現 forward
方法來定義前向傳播。
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(784, 128) # 以 MNIST 數據集為例
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
損失函數和優化器
1. 選擇損失函數
PyTorch 提供了多種損失函數,如 nn.CrossEntropyLoss
、nn.MSELoss
等。根據任務的不同,選擇合適的損失函數。
criterion = nn.CrossEntropyLoss()
2. 選擇優化器
PyTorch 也提供了多種優化器,如 torch.optim.SGD
、torch.optim.Adam
等。優化器用于在訓練過程中更新模型的權重。
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
訓練循環
訓練循環是模型訓練的核心,它包括前向傳播、計算損失、反向傳播和權重更新。
model = MyModel()
num_epochs = 10
for epoch in range(num_epochs):
for data, labels in data_loader:
optimizer.zero_grad() # 清空梯度
outputs = model(data) # 前向傳播
loss = criterion(outputs, labels) # 計算損失
loss.backward() # 反向傳播
optimizer.step() # 更新權重
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
模型評估
在訓練過程中,我們還需要定期評估模型的性能,以監控訓練進度和過擬合情況。
def evaluate(model, data_loader):
model.eval() # 設置為評估模式
total = 0
correct = 0
with torch.no_grad(): # 禁用梯度計算
for data, labels in data_loader:
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
model.train() # 恢復訓練模式
-
模型
+關注
關注
1文章
3226瀏覽量
48807 -
機器學習
+關注
關注
66文章
8406瀏覽量
132561 -
自然語言處理
+關注
關注
1文章
618瀏覽量
13552 -
pytorch
+關注
關注
2文章
807瀏覽量
13199
發布評論請先 登錄
相關推薦
評論