在前面的理論講解和網(wǎng)絡(luò)實(shí)現(xiàn)中,我們斷斷續(xù)續(xù)的學(xué)習(xí)了 Tensorflow 和 keras 兩個(gè)著名的深度學(xué)習(xí)框架。當(dāng)然主要還是 Tensorflow,keras 的底層計(jì)算都是以 Tensorflow 為后端的。在正式進(jìn)入下一環(huán)節(jié)的學(xué)習(xí)前,筆者先給 pytorch 入個(gè)門,至于系統(tǒng)的學(xué)習(xí),還是需要依靠各種項(xiàng)目實(shí)戰(zhàn)來(lái)鍛煉。
pytorch 是一款可以媲美于 Tensorflow 優(yōu)秀的深度學(xué)習(xí)計(jì)算框架,但又相比于 Tensorflow 在語(yǔ)法上更具備靈活性。pytorch 原生于一款小眾語(yǔ)言 lua,而后基于 python 版本后具備了強(qiáng)大的生命力。作為一款基于 python 的深度學(xué)習(xí)計(jì)算庫(kù),pytorch 提供了高于 numpy 的強(qiáng)大的張量計(jì)算能力和兼具靈活度和速度的深度學(xué)習(xí)研究功能。
下面筆者就以 pytorch 基本張量運(yùn)算、自動(dòng)求導(dǎo)機(jī)制和基于 LeNet-5 的訓(xùn)練實(shí)例對(duì) pytorch 進(jìn)行一個(gè)快速的入門和上手。
和學(xué)習(xí) Tensorflow 中的張量 tensor 一樣,torch 的張量運(yùn)算也可以理解為 numpy 科學(xué)計(jì)算的加強(qiáng)版。底層的計(jì)算邏輯基本一致,torch 張量的強(qiáng)大之處可以利用 GPU 來(lái)加速運(yùn)算。
創(chuàng)建一個(gè) 2x3 的矩陣:
x = torch.Tensor(2, 3)print(x)
獲取矩陣的大?。?/span>
print(x.size())
torch.Size([2, 3])
執(zhí)行張量運(yùn)算:
y = torch.rand(2, 3)print(x + y)
或者是提供一種指定輸出張量的運(yùn)算語(yǔ)法:
result = torch.Tensor(2, 3)torch.add(x, y, out = result)print(result)
當(dāng)然 torch 也可以方便的與 numpy 數(shù)組進(jìn)行轉(zhuǎn)換。
torch 張量轉(zhuǎn)為 numpy 數(shù)組:
a = torch.ones(5).numpy()print(a)
[1. 1. 1. 1. 1.]
numpy 數(shù)組轉(zhuǎn)為 torch 張量:
import numpy as npprint(torch.from_numpy(np.ones(5)))
使用 .cuda 方法將 tensor 在 GPU 上運(yùn)行:
if torch.cuda.is_available(): x = x.cuda() y = y.cuda() x + y
由上述操作可見,torch 的張量運(yùn)算和 numpy 一樣非常簡(jiǎn)單,相較于 tensorflow 的張量運(yùn)算要更加靈活。
在神經(jīng)網(wǎng)絡(luò)的反向傳播中涉及了大量的求導(dǎo)運(yùn)算,pytorch 中求導(dǎo)的核心計(jì)算模塊 autograd 可以幫助我們快速實(shí)現(xiàn)復(fù)雜的求導(dǎo)運(yùn)算。而求導(dǎo)運(yùn)算又是建立在 torch 變量 Variable 基礎(chǔ)之上的,Variable 對(duì) torch 的 Tensor 進(jìn)行了包裝,當(dāng)神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)和前向計(jì)算完成后,可以方便的對(duì)變量調(diào)用 backward 方法執(zhí)行反向計(jì)算。
創(chuàng)建一個(gè) Variable:
from torch.autograd import Variablex = Variable(torch.ones(2, 3), requires_grad = True)print(x)
執(zhí)行梯度計(jì)算:
y = x + 2z = y * y * 5out = z.mean()out.backward()print(x.grad)
需要注意的是 torch 計(jì)算梯度時(shí)關(guān)于目標(biāo)變量的梯度計(jì)算的表達(dá)方式為 Variable.grad
。
LeNet-5 網(wǎng)絡(luò)我們?cè)诘?14 講的時(shí)候已經(jīng)對(duì)論文進(jìn)行了詳細(xì)的解讀和實(shí)現(xiàn)。參看第 14 講的鏈接:深度學(xué)習(xí)筆記14:CNN經(jīng)典論文研讀之Le-Net5及其Tensorflow實(shí)現(xiàn)
數(shù)據(jù)準(zhǔn)備和轉(zhuǎn)換:
transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
torchvision 是用來(lái)服務(wù)于 torch 包的,用于生成、轉(zhuǎn)換和準(zhǔn)備預(yù)訓(xùn)練模型。cifar10 數(shù)據(jù)集:
定義 LeNet-5 網(wǎng)絡(luò)結(jié)構(gòu):
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10)
def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x)
return xnet = Net()
定義損失函數(shù)和優(yōu)化器:
import torch.optim as optimcriterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
訓(xùn)練 LeNet-5:
for epoch in range(5): running_loss = 0.0 for i, data in enumerate(trainloader, 0):
# get input data inputs, labels = data
# variable the data inputs, labels = Variable(inputs), Variable(labels)
# gradients zeros optimizer.zero_grad()
# forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step()
# print model train info running_loss += loss.data[0]
if i % 2000 == 1999: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0
print('Finished Training')
在測(cè)試集上展示訓(xùn)練效果:
import matplotlib.pyplot as plt
def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0)))
# test the net on test datasets# ground truth
dataiter = iter(testloader)images, labels = dataiter.next()
# print image
imshow(torchvision.utils.make_grid(images))print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
看看 LeNet-5 的預(yù)測(cè)結(jié)果:
# the net predict result
outputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
(‘Predicted: ‘, ‘ cat ship ship plane’)
貌似訓(xùn)練效果很好。再來(lái)看一下模型在全部測(cè)試集上的表現(xiàn):
# test on all test data
correct = 0
total = 0
for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
Accuracy of the network on the 10000 test images: 61 %
準(zhǔn)確率達(dá)到 61%,已經(jīng)遠(yuǎn)超隨機(jī)猜測(cè)的10%的準(zhǔn)確率了。
再看看模型在每一類別上的分類準(zhǔn)確率的表現(xiàn):
class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))
for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = torch.max(outputs.data, 1) c = (predicted == labels).squeeze()
for i in range(4): label = labels[i] class_correct[label] += c[i] class_total[label] += 1
for i in range(10): print('Accuracy of %5s : %2d %%' % ( classes[i], 100 * class_correct[i] / class_total[i]))
可見模型在貓和鳥等小型動(dòng)物上分類效果較差,在車船飛機(jī)等大型物體上效果較好。該實(shí)例來(lái)自于 pytorch 的官方 tutorial 60 分鐘快速上手文檔,能夠讓大家非??焖俚娜腴T學(xué)習(xí) pytorch 。 在對(duì)神經(jīng)網(wǎng)絡(luò)的基本原理有深刻理解的基礎(chǔ)上,對(duì)比之前的 tensorflow 和 keras,相信大家都能快速掌握 pytorch 。
參考資料:
http://pytorch.org/tutorials/
http://pytorch.apachecn.org/cn/tutorials/
聯(lián)系客服