好湿?好紧?好多水好爽自慰,久久久噜久噜久久综合,成人做爰A片免费看黄冈,机机对机机30分钟无遮挡

主頁 > 知識庫 > 超詳細PyTorch實現手寫數字識別器的示例代碼

超詳細PyTorch實現手寫數字識別器的示例代碼

熱門標簽:原裝電話機器人 西藏智能外呼系統五星服務 千陽自動外呼系統 平頂山外呼系統免費 清遠360地圖標注方法 400電話申請服務商選什么 工廠智能電話機器人 在哪里辦理400電話號碼 江蘇客服外呼系統廠家

前言

深度學習中有很多玩具數據,mnist就是其中一個,一個人能否入門深度學習往往就是以能否玩轉mnist數據來判斷的,在前面很多基礎介紹后我們就可以來實現一個簡單的手寫數字識別的網絡了

數據的處理

我們使用pytorch自帶的包進行數據的預處理

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5), (0.5))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,num_workers=2)

注釋:transforms.Normalize用于數據的標準化,具體實現
mean:均值 總和后除個數
std:方差 每個元素減去均值再平方再除個數

norm_data = (tensor - mean) / std

這里就直接將圖片標準化到了-1到1的范圍,標準化的原因就是因為如果某個數在數據中很大很大,就導致其權重較大,從而影響到其他數據,而本身我們的數據都是平等的,所以標準化后將數據分布到-1到1的范圍,使得所有數據都不會有太大的權重導致網絡出現巨大的波動
trainloader現在是一個可迭代的對象,那么我們可以使用for循環進行遍歷了,由于是使用yield返回的數據,為了節約內存

觀察一下數據

def imshow(img):
   img = img / 2 + 0.5 # unnormalize
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()
# torchvision.utils.make_grid 將圖片進行拼接
imshow(torchvision.utils.make_grid(iter(trainloader).next()[0]))

構建網絡

from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=5) # 14
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 無參數學習因此無需設置兩個
    self.conv2 = nn.Conv2d(in_channels=28, out_channels=28*2, kernel_size=5) # 7
    self.fc1 = nn.Linear(in_features=28*2*4*4, out_features=1024)
    self.fc2 = nn.Linear(in_features=1024, out_features=10)
  def forward(self, inputs):
    x = self.pool(F.relu(self.conv1(inputs)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(inputs.size()[0],-1)
    x = F.relu(self.fc1(x))
    return self.fc2(x)

下面是卷積的動態演示

in_channels:為輸入通道數 彩色圖片有3個通道 黑白有1個通道
out_channels:輸出通道數
kernel_size:卷積核的大小
stride:卷積的步長
padding:外邊距大小

輸出的size計算公式

  • h = (h - kernel_size + 2*padding)/stride + 1
  • w = (w - kernel_size + 2*padding)/stride + 1

MaxPool2d:是沒有參數進行運算的

實例化網絡優化器,并且使用GPU進行訓練

net = Net()
opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
Net(
 (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1))
 (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (conv2): Conv2d(28, 56, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=896, out_features=1024, bias=True)
 (fc2): Linear(in_features=1024, out_features=10, bias=True)
)

訓練主要代碼

for epoch in range(50):
  for images, labels in trainloader:
    images = images.to(device)
    labels = labels.to(device)
    pre_label = net(images)
    loss = F.cross_entropy(input=pre_label, target=labels).mean()
    pre_label = torch.argmax(pre_label, dim=1)
    acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
    net.zero_grad()
    loss.backward()
    opt.step()
  print(acc.detach().cpu().numpy(), loss.detach().cpu().numpy())

F.cross_entropy交叉熵函數


源碼中已經幫助我們實現了softmax因此不需要自己進行softmax操作了
torch.argmax計算最大數所在索引值

acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
# pre_label==labels 相同維度進行比較相同返回True不同的返回False,True為1 False為0, 即可獲取到相等的個數,再除總個數,就得到了Accuracy準確度了

預測

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True,num_workers=2)
images, labels = iter(testloader).next()
images = images.to(device)
labels = labels.to(device)
with torch.no_grad():
  pre_label = net(images)
  pre_label = torch.argmax(pre_label, dim=1)
  acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
  print(acc)

總結

本節我們了解了標準化數據·卷積的原理、簡答的構建了一個網絡,并讓它去識別手寫體,也是對前面章節的總匯了

到此這篇關于超詳細PyTorch實現手寫數字識別器的示例代碼的文章就介紹到這了,更多相關PyTorch 手寫數字識別器內容請搜索腳本之家以前的文章或繼續瀏覽下面的相關文章希望大家以后多多支持腳本之家!

您可能感興趣的文章:
  • pytorch 利用lstm做mnist手寫數字識別分類的實例
  • 詳解PyTorch手寫數字識別(MNIST數據集)
  • PyTorch CNN實戰之MNIST手寫數字識別示例
  • Pytorch實現圖像識別之數字識別(附詳細注釋)

標簽:天水 錦州 隨州 股票 日照 白城 西安 安慶

巨人網絡通訊聲明:本文標題《超詳細PyTorch實現手寫數字識別器的示例代碼》,本文關鍵詞  超,詳細,PyTorch,實現,手寫,;如發現本文內容存在版權問題,煩請提供相關信息告之我們,我們將及時溝通與處理。本站內容系統采集于網絡,涉及言論、版權與本站無關。
  • 相關文章
  • 下面列出與本文章《超詳細PyTorch實現手寫數字識別器的示例代碼》相關的同類信息!
  • 本頁收集關于超詳細PyTorch實現手寫數字識別器的示例代碼的相關信息資訊供網民參考!
  • 推薦文章