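This article walks through training a neural network on the Rock Paper Scissors image dataset. The dataset contains images of hand gestures from people of different ethnicities, ages, and genders.
First, download the Rock Paper Scissors training and test sets: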

import ssl
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen

DOWNLOAD_DIR = Path("D:/mldownload")
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

RPS_URL = "https://storage.googleapis.com/learning-datasets/rps.zip"
RPS_TEST_URL = "https://storage.googleapis.com/learning-datasets/rps-test-set.zip"
RPS_ZIP = DOWNLOAD_DIR / "rps.zip"
RPS_TEST_ZIP = DOWNLOAD_DIR / "rps-test-set.zip"

def download_file(url, destination):
    if destination.exists() and destination.stat().st_size > 0:
        print(f"File already exists, skipping: {destination}")
        return

    temp_path = destination.with_suffix(destination.suffix + ".part")
    print(f"Downloading: {url}")

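    try:
        response = urlopen(url, timeout=120)
    except URLError:
        # Fallback for machines whose certificate store is broken. This turns off
        # TLS certificate verification, so use it only for trusted public data.
        print("Warning: retrying without TLS certificate verification")
        context = ssl._create_unverified_context()
        response = urlopen(url, timeout=120, context=context)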

    with response, temp_path.open("wb") as file:
        while True:
            data = response.read(1024 * 1024)
            if not data:
                break
            file.write(data)

    temp_path.replace(destination)
    size_mb = destination.stat().st_size / 1024 / 1024
    print(f"Downloaded: {destination} ({size_mb:.1f} MB)")

download_file(RPS_URL, RPS_ZIP)
download_file(RPS_TEST_URL, RPS_TEST_ZIP)

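Note: set the download directory to suit your own environment. If the code above cannot download the dataset, try downloading it manually in a browser.
Next, extract the downloaded archives.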

import zipfile

def extract_zip(zip_path, extract_dir):
    if not zip_path.exists():
        raise FileNotFoundError(f"Zip file not found. Run the download cell first: {zip_path}")

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        bad_file = zip_ref.testzip()
        if bad_file is not None:
            raise zipfile.BadZipFile(
                f"Zip file looks corrupted at {bad_file}. Delete {zip_path} and download again."
            )
        zip_ref.extractall(extract_dir)

    print(f"Extracted: {zip_path} -> {extract_dir}")

extract_zip(RPS_ZIP, DOWNLOAD_DIR)
extract_zip(RPS_TEST_ZIP, DOWNLOAD_DIR)

Check the extraction result and print some information about the dataset.

rock_dir = DOWNLOAD_DIR / "rps" / "rock"
paper_dir = DOWNLOAD_DIR / "rps" / "paper"
scissors_dir = DOWNLOAD_DIR / "rps" / "scissors"

for image_dir in [rock_dir, paper_dir, scissors_dir]:
    if not image_dir.exists():
        raise FileNotFoundError(
            f"Directory not found: {image_dir}. Run the download and extract cells first."
        )

rock_files = sorted(path.name for path in rock_dir.iterdir())
paper_files = sorted(path.name for path in paper_dir.iterdir())
scissors_files = sorted(path.name for path in scissors_dir.iterdir())

print("total training rock images:", len(rock_files))
print("total training paper images:", len(paper_files))
print("total training scissors images:", len(scissors_files))

print(rock_files[:10])
print(paper_files[:10])
print(scissors_files[:10])
total training rock images: 840
total training paper images: 840
total training scissors images: 840
['rock01-000.png', 'rock01-001.png', 'rock01-002.png', 'rock01-003.png', 'rock01-004.png', 'rock01-005.png', 'rock01-006.png', 'rock01-007.png', 'rock01-008.png', 'rock01-009.png']
['paper01-000.png', 'paper01-001.png', 'paper01-002.png', 'paper01-003.png', 'paper01-004.png', 'paper01-005.png', 'paper01-006.png', 'paper01-007.png', 'paper01-008.png', 'paper01-009.png']
['scissors01-000.png', 'scissors01-001.png', 'scissors01-002.png', 'scissors01-003.png', 'scissors01-004.png', 'scissors01-005.png', 'scissors01-006.png', 'scissors01-007.png', 'scissors01-008.png', 'scissors01-009.png']

Display two training images from each of the rock, paper, and scissors classes.

%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

pic_index = 2

next_rock = [rock_dir / fname for fname in rock_files[pic_index - 2:pic_index]]
next_paper = [paper_dir / fname for fname in paper_files[pic_index - 2:pic_index]]
next_scissors = [scissors_dir / fname for fname in scissors_files[pic_index - 2:pic_index]]

for img_path in next_rock + next_paper + next_scissors:
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis("off")
    plt.show()


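Next, use the Keras API in TensorFlow to train and evaluate the model. Keras is an open-source neural network library, and TensorFlow bundles it as tf.keras, so it is convenient to use.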

import tensorflow as tf
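# keras_preprocessing is the standalone package used for the run shown below;
# in TensorFlow 2.x the same class is also exposed under tf.keras.preprocessing.image.
from keras_preprocessing.image import ImageDataGenerator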

TRAINING_DIR = "D:/mldownload/rps/"
training_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

VALIDATION_DIR = "D:/mldownload/rps-test-set/"
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = training_datagen.flow_from_directory(
    TRAINING_DIR,
    target_size=(150, 150),
    class_mode='categorical',
    batch_size=126
)

validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    target_size=(150, 150),
    class_mode='categorical',
    batch_size=126
)

model = tf.keras.models.Sequential([
    # Note the input shape is the desired size of the image: 150x150 with 3 color channels
    # This is the first convolution
    tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The second convolution
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # The third convolution
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # The fourth convolution
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    # Flatten the results to feed into a DNN
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    # 512 neuron hidden layer
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])


model.summary()

model.compile(loss = 'categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

history = model.fit(train_generator, epochs=25, steps_per_epoch=20, validation_data = validation_generator, verbose = 1, validation_steps=3)

model.save("rps.h5")
Found 2520 images belonging to 3 classes.
Found 372 images belonging to 3 classes.
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 148, 148, 64)      1792      
                                                                 
 max_pooling2d (MaxPooling2D  (None, 74, 74, 64)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 72, 72, 64)        36928     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 36, 36, 64)       0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 34, 34, 128)       73856     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 17, 17, 128)      0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 15, 15, 128)       147584    
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 7, 7, 128)        0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 6272)              0         
                                                                 
 dropout (Dropout)           (None, 6272)              0         
                                                                 
 dense (Dense)               (None, 512)               3211776   
                                                                 
 dense_1 (Dense)             (None, 3)                 1539      
                                                                 
=================================================================
Total params: 3,473,475
Trainable params: 3,473,475
Non-trainable params: 0
_________________________________________________________________
Epoch 1/25
20/20 [==============================] - 53s 3s/step - loss: 1.4751 - accuracy: 0.3599 - val_loss: 1.1369 - val_accuracy: 0.3333
Epoch 2/25
20/20 [==============================] - 49s 2s/step - loss: 1.1303 - accuracy: 0.3603 - val_loss: 1.0965 - val_accuracy: 0.5108
Epoch 3/25
20/20 [==============================] - 47s 2s/step - loss: 1.0863 - accuracy: 0.4032 - val_loss: 0.9790 - val_accuracy: 0.3978
Epoch 4/25
20/20 [==============================] - 50s 2s/step - loss: 1.0418 - accuracy: 0.5139 - val_loss: 0.8253 - val_accuracy: 0.7258
Epoch 5/25
20/20 [==============================] - 50s 2s/step - loss: 0.8743 - accuracy: 0.6087 - val_loss: 0.4759 - val_accuracy: 0.9651
Epoch 6/25
20/20 [==============================] - 48s 2s/step - loss: 0.8080 - accuracy: 0.6345 - val_loss: 0.6926 - val_accuracy: 0.6183
Epoch 7/25
20/20 [==============================] - 46s 2s/step - loss: 0.6538 - accuracy: 0.7103 - val_loss: 0.2193 - val_accuracy: 0.9785
Epoch 8/25
20/20 [==============================] - 46s 2s/step - loss: 0.5827 - accuracy: 0.7579 - val_loss: 0.2920 - val_accuracy: 0.9731
Epoch 9/25
20/20 [==============================] - 45s 2s/step - loss: 0.4396 - accuracy: 0.8286 - val_loss: 0.0803 - val_accuracy: 1.0000
Epoch 10/25
20/20 [==============================] - 47s 2s/step - loss: 0.3461 - accuracy: 0.8560 - val_loss: 0.3216 - val_accuracy: 0.7634
Epoch 11/25
20/20 [==============================] - 45s 2s/step - loss: 0.3198 - accuracy: 0.8730 - val_loss: 0.0706 - val_accuracy: 0.9651
Epoch 12/25
20/20 [==============================] - 45s 2s/step - loss: 0.2977 - accuracy: 0.8861 - val_loss: 0.0884 - val_accuracy: 0.9651
Epoch 13/25
20/20 [==============================] - 47s 2s/step - loss: 0.2832 - accuracy: 0.8952 - val_loss: 0.0391 - val_accuracy: 0.9839
Epoch 14/25
20/20 [==============================] - 43s 2s/step - loss: 0.1713 - accuracy: 0.9353 - val_loss: 0.0592 - val_accuracy: 0.9758
Epoch 15/25
20/20 [==============================] - 44s 2s/step - loss: 0.2972 - accuracy: 0.8913 - val_loss: 0.1070 - val_accuracy: 0.9839
Epoch 16/25
20/20 [==============================] - 48s 2s/step - loss: 0.1306 - accuracy: 0.9575 - val_loss: 0.0549 - val_accuracy: 0.9785
Epoch 17/25
20/20 [==============================] - 46s 2s/step - loss: 0.1886 - accuracy: 0.9226 - val_loss: 0.0500 - val_accuracy: 0.9866
Epoch 18/25
20/20 [==============================] - 45s 2s/step - loss: 0.1101 - accuracy: 0.9615 - val_loss: 0.0518 - val_accuracy: 0.9651
Epoch 19/25
20/20 [==============================] - 48s 2s/step - loss: 0.1343 - accuracy: 0.9556 - val_loss: 0.0105 - val_accuracy: 1.0000
Epoch 20/25
20/20 [==============================] - 45s 2s/step - loss: 0.1349 - accuracy: 0.9528 - val_loss: 0.2117 - val_accuracy: 0.8952
Epoch 21/25
20/20 [==============================] - 50s 2s/step - loss: 0.0918 - accuracy: 0.9687 - val_loss: 0.0844 - val_accuracy: 0.9651
Epoch 22/25
20/20 [==============================] - 50s 2s/step - loss: 0.2075 - accuracy: 0.9317 - val_loss: 0.0403 - val_accuracy: 0.9839
Epoch 23/25
20/20 [==============================] - 51s 3s/step - loss: 0.0937 - accuracy: 0.9675 - val_loss: 0.7548 - val_accuracy: 0.7070
Epoch 24/25
20/20 [==============================] - 46s 2s/step - loss: 0.0861 - accuracy: 0.9694 - val_loss: 0.1306 - val_accuracy: 0.9435
Epoch 25/25
20/20 [==============================] - 47s 2s/step - loss: 0.1002 - accuracy: 0.9655 - val_loss: 0.0382 - val_accuracy: 0.9839
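
The batch size and step counts passed to model.fit above follow from the dataset sizes reported by flow_from_directory (2520 training images and 372 validation images). The short sketch below reproduces that arithmetic; the helper name steps_for is made up for illustration and is not part of Keras.

```python
import math

def steps_for(num_images, batch_size):
    """Number of batches needed to see every image once (hypothetical helper)."""
    return math.ceil(num_images / batch_size)

BATCH_SIZE = 126

train_steps = steps_for(2520, BATCH_SIZE)  # 2520 / 126 is exactly 20
val_steps = steps_for(372, BATCH_SIZE)     # 372 / 126 is about 2.95, rounded up to 3

# The last validation batch is smaller than the others.
last_val_batch = 372 - (val_steps - 1) * BATCH_SIZE

print(train_steps, val_steps, last_val_batch)  # 20 3 120
```

These values match steps_per_epoch=20 and validation_steps=3, so each epoch covers the training set once and the whole validation set is evaluated.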

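ImageDataGenerator is the Keras class for image preprocessing. Here it rescales pixel values to the range [0, 1] and, for the training set only, applies random rotations, shifts, shears, zooms, and horizontal flips. This augmentation exposes the network to more varied inputs, which generally helps it generalize.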
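
To make two of these options concrete, here is a minimal pure-Python sketch of what rescale=1./255 and horizontal_flip=True do to a single image. It is only an illustration on a made-up 2x3 grayscale image, not the Keras implementation; the function names and the 50% flip probability are assumptions chosen for the example.

```python
import random

def rescale(image, factor=1.0 / 255):
    """Multiply every pixel by a factor, like rescale=1./255."""
    return [[pixel * factor for pixel in row] for row in image]

def horizontal_flip(image):
    """Mirror the image left to right."""
    return [row[::-1] for row in image]

def augment(image, rng):
    """Flip with an assumed probability of 0.5, then rescale."""
    if rng.random() < 0.5:
        image = horizontal_flip(image)
    return rescale(image)

# A hypothetical 2x3 grayscale image with 8-bit pixel values.
tiny_image = [[0, 51, 255],
              [102, 204, 153]]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
augmented = augment(tiny_image, rng)
print(augmented)
```

The validation generator in the article applies only the rescaling step, so validation images are never randomly altered.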

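Sequential defines a neural network as a linear stack of layers with a single input and a single output. The layers are stacked in the order given, and the output of each layer is the input of the next.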
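
Because each layer feeds the next, the output shapes and parameter counts in the model.summary() output above can be checked by hand. The sketch below does so in plain Python, assuming the Keras defaults used in this model (stride 1 and 'valid' padding for Conv2D, non-overlapping 2x2 pooling); the helper functions are written for illustration only.

```python
def conv2d(shape, filters, kernel=3):
    """'valid' convolution, stride 1: each spatial side shrinks by kernel - 1."""
    h, w, c = shape
    params = (kernel * kernel * c + 1) * filters  # weights plus one bias per filter
    return (h - kernel + 1, w - kernel + 1, filters), params

def max_pool(shape, pool=2):
    """Non-overlapping pooling: each spatial side is floor-divided by the pool size."""
    h, w, c = shape
    return (h // pool, w // pool, c)

def dense(units_in, units_out):
    """Fully connected layer: one weight per input-output pair plus one bias per output."""
    return units_out, (units_in + 1) * units_out

shape = (150, 150, 3)
total_params = 0
for filters in (64, 64, 128, 128):
    shape, params = conv2d(shape, filters)
    total_params += params
    shape = max_pool(shape)

flat = shape[0] * shape[1] * shape[2]  # Flatten; Dropout does not change the size
units, params = dense(flat, 512)
total_params += params
units, params = dense(units, 3)
total_params += params

print(shape, flat, total_params)  # (7, 7, 128) 6272 3473475
```

The result agrees with the 6272 flattened features and the 3,473,475 total parameters reported in the summary.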

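Two operations are used repeatedly in this network: convolution and pooling. Images can contain noise or weak signals, and convolution (the Conv2D layers here) lets the network pick out specific local features such as edges. Each Conv2D layer here uses 3x3 filters whose weights are learned during training; a well-known hand-crafted example of a 3x3 filter is the vertical Sobel filter, which responds to vertical edges. Pooling (the MaxPooling2D layers here) downsamples the convolution output, which contains a lot of redundant information; by shrinking its input, pooling reduces the number of values passed on. For more detail, see the Zhihu answer “如何理解卷积神经网络(CNN)中的卷积和池化?” ("How should one understand convolution and pooling in a convolutional neural network (CNN)?"). For more on convolution arithmetic, see the Convolution arithmetic project on GitHub.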
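
As a small worked example of both operations, the sketch below applies a vertical Sobel filter to a made-up 6x6 image with a dark left half and a bright right half, then runs 2x2 max pooling on the result. It is plain Python for illustration; the filter is fixed by hand here, whereas the Conv2D layers in the article learn their filter weights.

```python
def conv2d_valid(image, kernel):
    """Slide the kernel over the image with stride 1 and no padding."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[i + di][j + dj] * kernel[di][dj]
                for di in range(kh) for dj in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

def max_pool_2x2(feature_map):
    """Keep the largest value in each non-overlapping 2x2 window."""
    return [
        [
            max(feature_map[i][j], feature_map[i][j + 1],
                feature_map[i + 1][j], feature_map[i + 1][j + 1])
            for j in range(0, len(feature_map[0]) - 1, 2)
        ]
        for i in range(0, len(feature_map) - 1, 2)
    ]

# Vertical Sobel filter: responds where brightness changes from left to right.
SOBEL_VERTICAL = [[-1, 0, 1],
                  [-2, 0, 2],
                  [-1, 0, 1]]

# Hypothetical 6x6 image: left half dark (0), right half bright (1).
image = [[0, 0, 0, 1, 1, 1] for _ in range(6)]

edges = conv2d_valid(image, SOBEL_VERTICAL)
pooled = max_pool_2x2(edges)

print(edges[0])  # [0, 4, 4, 0]
print(pooled)    # [[4, 4], [4, 4]]
```

The 6x6 input becomes a 4x4 feature map that is nonzero only around the vertical edge, and pooling halves it to 2x2 while keeping the strongest responses.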

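A Dense layer is a fully connected layer: in essence a linear transformation from one feature space to another, followed by an activation function. Its purpose is to combine the features extracted by the earlier layers through a nonlinear transformation, capture the relationships between them, and finally map them to the output space. In this model the first Dense layer (512 units, ReLU) is a hidden layer, and the last Dense layer (3 units, softmax) is the output layer, with one unit per class.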
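
The following sketch shows the same two-step pattern (a ReLU hidden layer followed by a softmax output layer) on a tiny made-up example in plain Python. The weights, biases, and layer sizes are invented for illustration and are far smaller than the 512-unit layer in the article.

```python
import math

def dense(inputs, weights, biases):
    """Linear part of a fully connected layer: one weighted sum per output unit."""
    return [sum(x * w for x, w in zip(inputs, unit_weights)) + b
            for unit_weights, b in zip(weights, biases)]

def relu(values):
    return [max(0.0, v) for v in values]

def softmax(values):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(values)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Two hypothetical input features.
features = [1.0, -2.0]

# Hidden layer with 3 units (made-up weights and biases).
hidden = relu(dense(features,
                    weights=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                    biases=[0.0, 0.0, 1.5]))

# Output layer with 3 units, one per class.
probabilities = softmax(dense(hidden,
                              weights=[[1.0, 0.0, 0.0],
                                       [0.0, 1.0, 0.0],
                                       [0.0, 0.0, 1.0]],
                              biases=[0.0, 0.0, 0.0]))

print(hidden)  # [1.0, 0.0, 0.5]
print(probabilities)
```

The largest probability falls on the first class because the first hidden unit has the strongest activation.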

Once training is complete, plot the training and validation results.

import matplotlib.pyplot as plt

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Accuracy curves
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)

# Loss curves in a second figure, so no empty figure is created
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc=0)
plt.show()

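With the trained model we can run real examples, such as uploading a rock, paper, or scissors image and classifying it with model.predict. We do not go into that here; later, when we build an Android app with TensorFlow Lite, the phone's camera or photo gallery becomes a natural source of input images.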
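
For readers who do want to try model.predict, the part that is easy to get wrong is mapping the three output probabilities back to a label. flow_from_directory assigns class indices in alphanumeric order of the subdirectory names, so for this dataset the order is paper, rock, scissors (train_generator.class_indices shows the mapping). The sketch below assumes a probability row such as model.predict would return for one image; the helper decode_prediction and the sample numbers are made up for illustration.

```python
# Class order as assigned by flow_from_directory for the rps folders.
CLASS_NAMES = ["paper", "rock", "scissors"]

def decode_prediction(probabilities, class_names=CLASS_NAMES):
    """Return the most likely label and its probability (hypothetical helper)."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_names[best], probabilities[best]

# Made-up softmax output for a single image.
sample_output = [0.05, 0.90, 0.05]

label, confidence = decode_prediction(sample_output)
print(label, confidence)  # rock 0.9
```

The input image would also need the same preprocessing as in training: resized to 150x150 and scaled by 1/255.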
