Batch processing a dataset with Python

_jym 2022-01-07 14:45:20



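The figure below shows how the array shapes change when the network processes a single image, compared with processing a batch of 100 images at once: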

[Figure: array shapes when processing one image vs. a batch of 100 images]

These array shapes can be printed from code:

import sys, os
sys.path.append(os.pardir)  # allow importing dataset/ from the parent directory
import pickle
from dataset.mnist import load_mnist
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network
x, t = get_data()
network = init_network()
print(x.shape)     # the whole test set
print(x[0].shape)  # a single image
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(W1.shape)
print(W2.shape)
print(W3.shape)

Output:

(10000, 784)
(784,)
(784, 50)
(50, 100)
(100, 10)
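The shape changes for a single image versus a batch can be reproduced with random matrices that have the weight shapes printed above (a stand-in for sample_weight.pkl, not its real values). Biases and activation functions are left out of this sketch because they act element-wise (the bias is broadcast across rows) and do not change any shape.

```python
import numpy as np

# Hypothetical random weights with the shapes printed above.
rng = np.random.default_rng(0)
W1 = rng.standard_normal((784, 50))
W2 = rng.standard_normal((50, 100))
W3 = rng.standard_normal((100, 10))

def forward_shapes(x):
    """Return the shape of the input and of each layer's matrix product."""
    a1 = x @ W1
    a2 = a1 @ W2
    a3 = a2 @ W3
    return [x.shape, a1.shape, a2.shape, a3.shape]

single = rng.standard_normal(784)        # one flattened 28x28 image
batch = rng.standard_normal((100, 784))  # 100 images stacked as rows

print(forward_shapes(single))  # [(784,), (50,), (100,), (10,)]
print(forward_shapes(batch))   # [(100, 784), (100, 50), (100, 100), (100, 10)]
```

The only difference is the leading dimension: the 100 in (100, 784) is carried through every product, so the final (100, 10) array holds one row of 10 scores per image.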

Batch-based implementation:

batch_size=100

for i in range(0, len(x), batch_size): this line makes i start at 0 and increase by batch_size (100) on each iteration.
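A quick check of the start indices this loop produces, assuming the 10,000-image test set shown in the output above (len(x) == 10000):

```python
n_samples = 10000  # len(x) for the test set, per the (10000, 784) output above
batch_size = 100

# range(start, stop, step) yields start, start + step, ... while the value is < stop
starts = list(range(0, n_samples, batch_size))
print(starts[:3])   # [0, 100, 200]
print(starts[-1])   # 9900
print(len(starts))  # 100
```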

x_batch = x[i:i+batch_size] takes the rows from index i up to, but not including, index i+100.

This splits the data into batches x[0:100], x[100:200], … .
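The sketch below slices a small stand-in array the same way. The 250-row size is hypothetical, chosen so the length is not a multiple of 100: NumPy slicing stops at the end of the array instead of raising an error, so the last batch is simply shorter. (For the 10,000-image test set every batch is full.)

```python
import numpy as np

# Hypothetical stand-in for x: 250 "images" of 784 features each.
x = np.zeros((250, 784))
batch_size = 100

# Same slicing as the article's loop; the stop index is exclusive.
batch_shapes = [x[i:i + batch_size].shape for i in range(0, len(x), batch_size)]
print(batch_shapes)  # [(100, 784), (100, 784), (50, 784)]
```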

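p = np.argmax(y_batch, axis=1) returns, for each row of y_batch, the index of its largest value. axis=1 means the maximum is taken along the second axis, i.e. across the 10 columns within each row, giving one index per image. In other words, each input image produces a row y of 10 scores, and the index (0 to 9) of the largest score is the digit the network guesses for that image.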
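A small example of the difference between axis=1 and axis=0, using a hypothetical output for 3 images over 4 classes (the real network's output has 10 columns, one per digit):

```python
import numpy as np

# Hypothetical softmax outputs: one row per image, one column per class.
y_batch = np.array([
    [0.10, 0.70, 0.15, 0.05],
    [0.60, 0.20, 0.10, 0.10],
    [0.05, 0.05, 0.10, 0.80],
])

rows = np.argmax(y_batch, axis=1)  # best class within each row (one per image)
cols = np.argmax(y_batch, axis=0)  # best image within each column (not what we want)
print(rows)  # [1 0 3]
print(cols)  # [1 0 0 2]
```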

import sys, os
sys.path.append(os.pardir)  # allow importing modules from the parent directory
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network
def predict(network, x):
    # x can be one image (784,) or a batch (N, 784)
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)
    return y
x, t = get_data()
network = init_network()
print(x.shape)
print(x[0].shape)
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(W1.shape)
print(W2.shape)
print(W3.shape)
batch_size = 100  # number of images per batch
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]  # rows i to i+batch_size-1
    y_batch = predict(network, x_batch)  # shape (batch_size, 10)
    p = np.argmax(y_batch, axis=1)  # predicted digit for each image
    accuracy_cnt += np.sum(p == t[i:i+batch_size])  # number of correct predictions
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
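To check that batching changes only how the rows are grouped, not the predictions, the sketch below runs the same loop on a random network. The sigmoid and softmax here are stand-ins for common.functions (the softmax is written to normalize each row, which is assumed to match how the original handles a 2-D input), and the weights, images and labels are random placeholders with the shapes printed above, so the accuracy value itself is meaningless.

```python
import numpy as np

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

def softmax(a):
    # Normalize along the last axis so a (batch, 10) input gets one
    # probability distribution per row; subtracting the row max avoids overflow.
    a = a - np.max(a, axis=-1, keepdims=True)
    e = np.exp(a)
    return e / np.sum(e, axis=-1, keepdims=True)

def predict(network, x):
    z1 = sigmoid(x @ network['W1'] + network['b1'])
    z2 = sigmoid(z1 @ network['W2'] + network['b2'])
    return softmax(z2 @ network['W3'] + network['b3'])

# Hypothetical random network, images and labels (stand-ins for
# sample_weight.pkl and the MNIST test set).
rng = np.random.default_rng(42)
network = {
    'W1': 0.1 * rng.standard_normal((784, 50)), 'b1': np.zeros(50),
    'W2': 0.1 * rng.standard_normal((50, 100)), 'b2': np.zeros(100),
    'W3': 0.1 * rng.standard_normal((100, 10)), 'b3': np.zeros(10),
}
x = rng.random((1000, 784))
t = rng.integers(0, 10, size=1000)

# Batched loop, mirroring the article's code.
batch_size = 100
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
    p = np.argmax(predict(network, x[i:i + batch_size]), axis=1)
    accuracy_cnt += np.sum(p == t[i:i + batch_size])  # each True counts as 1

# Whole array at once vs. one image at a time: the per-image results agree.
batched = predict(network, x)
one_by_one = np.array([predict(network, x[k]) for k in range(len(x))])
print(np.allclose(batched, one_by_one))  # True
print("Accuracy:", accuracy_cnt / len(x))
```

The batched version does the same arithmetic per image; the gain comes from letting NumPy run a few large matrix products instead of thousands of small ones.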
Copyright notice: this article was written by [_jym]; please include a link to the original when reposting. Thank you. https://gsmany.com/2022/01/202201071445204199.html