TensorFlow训练石头剪刀布数据集
本文将演示石头剪刀布图片库的神经网络训练过程。石头剪刀布数据集包含了不同的手势图片,来自不同的种族、年龄和性别。
首先下载石头剪刀布的训练集和测试集:
import ssl
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen
# Local folder that receives the downloaded archives; created (with any missing
# parents) as soon as this cell runs.
DOWNLOAD_DIR = Path("D:/mldownload")
DOWNLOAD_DIR.mkdir(exist_ok=True, parents=True)

# Remote locations of the training archive and the test archive.
RPS_URL = "https://storage.googleapis.com/learning-datasets/rps.zip"
RPS_TEST_URL = "https://storage.googleapis.com/learning-datasets/rps-test-set.zip"

# Local file names under DOWNLOAD_DIR for the two archives.
RPS_ZIP = DOWNLOAD_DIR.joinpath("rps.zip")
RPS_TEST_ZIP = DOWNLOAD_DIR.joinpath("rps-test-set.zip")
def download_file(url, destination):
    """Download ``url`` to ``destination`` unless a non-empty copy already exists.

    The payload is streamed into a sibling ``<name>.part`` file and renamed to
    ``destination`` only after the whole body has been written, so an
    interrupted run never leaves a truncated file under the final name.

    Args:
        url: Address passed to ``urlopen``.
        destination: ``pathlib.Path`` of the file to create.

    Raises:
        urllib.error.URLError: If the resource cannot be opened. Only TLS
            failures are retried (without certificate verification); every
            other failure is re-raised unchanged.
        OSError: If reading the response or writing the file fails.
    """
    if destination.exists() and destination.stat().st_size > 0:
        print(f"File already exists, skipping: {destination}")
        return

    temp_path = destination.with_suffix(destination.suffix + ".part")
    print(f"Downloading: {url}")

    try:
        response = urlopen(url, timeout=120)
    except URLError as error:
        # Relaxing certificate checks can only help with a TLS failure; DNS,
        # HTTP status or connection errors would fail again in the same way.
        if not isinstance(error.reason, ssl.SSLError):
            raise
        # SECURITY: the retry below does not authenticate the server. Say so
        # loudly instead of downgrading silently.
        print(
            f"WARNING: TLS verification failed ({error.reason}); "
            "retrying WITHOUT certificate verification."
        )
        context = ssl.create_default_context()
        context.check_hostname = False  # must be disabled before CERT_NONE
        context.verify_mode = ssl.CERT_NONE
        response = urlopen(url, timeout=120, context=context)

    try:
        with response, temp_path.open("wb") as file:
            # Copy in 1 MiB chunks until read() returns the empty bytes object.
            for data in iter(lambda: response.read(1024 * 1024), b""):
                file.write(data)
        temp_path.replace(destination)
    except BaseException:
        # Do not leave a stale partial file behind after a failed transfer.
        if temp_path.exists():
            temp_path.unlink()
        raise

    size_mb = destination.stat().st_size / 1024 / 1024
    print(f"Downloaded: {destination} ({size_mb:.1f} MB)")
# Fetch the training archive first, then the test archive.
for archive_url, archive_path in ((RPS_URL, RPS_ZIP), (RPS_TEST_URL, RPS_TEST_ZIP)):
    download_file(archive_url, archive_path)
注意:根据自己的实际情况设定下载目录。若上述代码无法下载数据集,尝试使用浏览器手动下载!
然后解压下载的数据集。
import zipfile
def extract_zip(zip_path, extract_dir):
    """Check the archive at ``zip_path`` and unpack it into ``extract_dir``.

    Args:
        zip_path: ``pathlib.Path`` of the zip archive.
        extract_dir: Folder that receives the archive's contents.

    Raises:
        FileNotFoundError: If ``zip_path`` does not exist.
        zipfile.BadZipFile: If ``testzip()`` reports a damaged member.
    """
    if not zip_path.exists():
        raise FileNotFoundError(
            f"Zip file not found. Run the download cell first: {zip_path}"
        )

    archive = zipfile.ZipFile(zip_path, mode="r")
    with archive:
        # testzip() returns the name of the first bad member, or None when
        # every member passes its check.
        corrupt_member = archive.testzip()
        if corrupt_member is not None:
            message = (
                f"Zip file looks corrupted at {corrupt_member}. "
                f"Delete {zip_path} and download again."
            )
            raise zipfile.BadZipFile(message)
        archive.extractall(path=extract_dir)

    print(f"Extracted: {zip_path} -> {extract_dir}")
# Unpack both archives side by side into the download folder.
for archive_path in (RPS_ZIP, RPS_TEST_ZIP):
    extract_zip(archive_path, DOWNLOAD_DIR)
检测数据集的解压结果,打印相关信息。
# One folder per class inside the extracted "rps" training set.
rock_dir, paper_dir, scissors_dir = (
    DOWNLOAD_DIR / "rps" / label for label in ("rock", "paper", "scissors")
)

# Stop early with a hint when a class folder is missing.
for image_dir in (rock_dir, paper_dir, scissors_dir):
    if image_dir.exists():
        continue
    raise FileNotFoundError(
        f"Directory not found: {image_dir}. Run the download and extract cells first."
    )

# Entry names of every class folder, sorted for a stable order.
rock_files = sorted([entry.name for entry in rock_dir.iterdir()])
paper_files = sorted([entry.name for entry in paper_dir.iterdir()])
scissors_files = sorted([entry.name for entry in scissors_dir.iterdir()])

listings = (
    ("rock", rock_files),
    ("paper", paper_files),
    ("scissors", scissors_files),
)

# First the per-class totals, then the first ten names of each class.
for label, names in listings:
    print(f"total training {label} images:", len(names))
for _, names in listings:
    print(names[:10])
total training rock images: 840
total training paper images: 840
total training scissors images: 840
['rock01-000.png', 'rock01-001.png', 'rock01-002.png', 'rock01-003.png', 'rock01-004.png', 'rock01-005.png', 'rock01-006.png', 'rock01-007.png', 'rock01-008.png', 'rock01-009.png']
['paper01-000.png', 'paper01-001.png', 'paper01-002.png', 'paper01-003.png', 'paper01-004.png', 'paper01-005.png', 'paper01-006.png', 'paper01-007.png', 'paper01-008.png', 'paper01-009.png']
['scissors01-000.png', 'scissors01-001.png', 'scissors01-002.png', 'scissors01-003.png', 'scissors01-004.png', 'scissors01-005.png', 'scissors01-006.png', 'scissors01-007.png', 'scissors01-008.png', 'scissors01-009.png']
各打印两张石头剪刀布训练集图片
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Preview the two files that end just before ``pic_index`` in every class.
pic_index = 2
preview = slice(pic_index - 2, pic_index)

next_rock = [rock_dir / fname for fname in rock_files[preview]]
next_paper = [paper_dir / fname for fname in paper_files[preview]]
next_scissors = [scissors_dir / fname for fname in scissors_files[preview]]

# Each image is drawn without axes and shown on its own.
for img_path in [*next_rock, *next_paper, *next_scissors]:
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis("off")
    plt.show()






调用TensorFlow的Keras接口进行模型的训练和评估。Keras是开源的人工神经网络库,TensorFlow集成了Keras的调用接口,可以方便地使用。
import tensorflow as tf
import keras_preprocessing
from keras_preprocessing import image
from keras_preprocessing.image import ImageDataGenerator
# Derive the dataset folders from DOWNLOAD_DIR so that changing the download
# location in one place keeps training in sync with the download/extract steps.
TRAINING_DIR = str(DOWNLOAD_DIR / "rps")
VALIDATION_DIR = str(DOWNLOAD_DIR / "rps-test-set")

# Shared by both generators and by the model's input layer.
IMAGE_SIZE = (150, 150)
BATCH_SIZE = 126

# Training images are rescaled to [0, 1] and randomly augmented.
training_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation images are only rescaled, never augmented.
validation_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = training_datagen.flow_from_directory(
    TRAINING_DIR,
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    batch_size=BATCH_SIZE,
)

validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    batch_size=BATCH_SIZE,
)

model = tf.keras.models.Sequential([
    # The input is a 150x150 image with 3 color channels.
    # This is the first convolution
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=IMAGE_SIZE + (3,)),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The second convolution
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The third convolution
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The fourth convolution
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # Flatten the results to feed into a DNN
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    # 512 neuron hidden layer
    tf.keras.layers.Dense(512, activation='relu'),
    # Output layer: one softmax unit per class
    tf.keras.layers.Dense(3, activation='softmax'),
])

model.summary()

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

# len(generator) is the number of batches needed to cover the data once. For
# the 2520 training / 372 test images at batch size 126 this is 20 and 3, the
# values that used to be hard-coded here.
history = model.fit(
    train_generator,
    epochs=25,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    verbose=1,
    validation_steps=len(validation_generator),
)

model.save("rps.h5")
Found 2520 images belonging to 3 classes.
Found 372 images belonging to 3 classes.
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 148, 148, 64) 1792
max_pooling2d (MaxPooling2D (None, 74, 74, 64) 0
)
conv2d_1 (Conv2D) (None, 72, 72, 64) 36928
max_pooling2d_1 (MaxPooling (None, 36, 36, 64) 0
2D)
conv2d_2 (Conv2D) (None, 34, 34, 128) 73856
max_pooling2d_2 (MaxPooling (None, 17, 17, 128) 0
2D)
conv2d_3 (Conv2D) (None, 15, 15, 128) 147584
max_pooling2d_3 (MaxPooling (None, 7, 7, 128) 0
2D)
flatten (Flatten) (None, 6272) 0
dropout (Dropout) (None, 6272) 0
dense (Dense) (None, 512) 3211776
dense_1 (Dense) (None, 3) 1539
=================================================================
Total params: 3,473,475
Trainable params: 3,473,475
Non-trainable params: 0
_________________________________________________________________
Epoch 1/25
20/20 [==============================] - 53s 3s/step - loss: 1.4751 - accuracy: 0.3599 - val_loss: 1.1369 - val_accuracy: 0.3333
Epoch 2/25
20/20 [==============================] - 49s 2s/step - loss: 1.1303 - accuracy: 0.3603 - val_loss: 1.0965 - val_accuracy: 0.5108
Epoch 3/25
20/20 [==============================] - 47s 2s/step - loss: 1.0863 - accuracy: 0.4032 - val_loss: 0.9790 - val_accuracy: 0.3978
Epoch 4/25
20/20 [==============================] - 50s 2s/step - loss: 1.0418 - accuracy: 0.5139 - val_loss: 0.8253 - val_accuracy: 0.7258
Epoch 5/25
20/20 [==============================] - 50s 2s/step - loss: 0.8743 - accuracy: 0.6087 - val_loss: 0.4759 - val_accuracy: 0.9651
Epoch 6/25
20/20 [==============================] - 48s 2s/step - loss: 0.8080 - accuracy: 0.6345 - val_loss: 0.6926 - val_accuracy: 0.6183
Epoch 7/25
20/20 [==============================] - 46s 2s/step - loss: 0.6538 - accuracy: 0.7103 - val_loss: 0.2193 - val_accuracy: 0.9785
Epoch 8/25
20/20 [==============================] - 46s 2s/step - loss: 0.5827 - accuracy: 0.7579 - val_loss: 0.2920 - val_accuracy: 0.9731
Epoch 9/25
20/20 [==============================] - 45s 2s/step - loss: 0.4396 - accuracy: 0.8286 - val_loss: 0.0803 - val_accuracy: 1.0000
Epoch 10/25
20/20 [==============================] - 47s 2s/step - loss: 0.3461 - accuracy: 0.8560 - val_loss: 0.3216 - val_accuracy: 0.7634
Epoch 11/25
20/20 [==============================] - 45s 2s/step - loss: 0.3198 - accuracy: 0.8730 - val_loss: 0.0706 - val_accuracy: 0.9651
Epoch 12/25
20/20 [==============================] - 45s 2s/step - loss: 0.2977 - accuracy: 0.8861 - val_loss: 0.0884 - val_accuracy: 0.9651
Epoch 13/25
20/20 [==============================] - 47s 2s/step - loss: 0.2832 - accuracy: 0.8952 - val_loss: 0.0391 - val_accuracy: 0.9839
Epoch 14/25
20/20 [==============================] - 43s 2s/step - loss: 0.1713 - accuracy: 0.9353 - val_loss: 0.0592 - val_accuracy: 0.9758
Epoch 15/25
20/20 [==============================] - 44s 2s/step - loss: 0.2972 - accuracy: 0.8913 - val_loss: 0.1070 - val_accuracy: 0.9839
Epoch 16/25
20/20 [==============================] - 48s 2s/step - loss: 0.1306 - accuracy: 0.9575 - val_loss: 0.0549 - val_accuracy: 0.9785
Epoch 17/25
20/20 [==============================] - 46s 2s/step - loss: 0.1886 - accuracy: 0.9226 - val_loss: 0.0500 - val_accuracy: 0.9866
Epoch 18/25
20/20 [==============================] - 45s 2s/step - loss: 0.1101 - accuracy: 0.9615 - val_loss: 0.0518 - val_accuracy: 0.9651
Epoch 19/25
20/20 [==============================] - 48s 2s/step - loss: 0.1343 - accuracy: 0.9556 - val_loss: 0.0105 - val_accuracy: 1.0000
Epoch 20/25
20/20 [==============================] - 45s 2s/step - loss: 0.1349 - accuracy: 0.9528 - val_loss: 0.2117 - val_accuracy: 0.8952
Epoch 21/25
20/20 [==============================] - 50s 2s/step - loss: 0.0918 - accuracy: 0.9687 - val_loss: 0.0844 - val_accuracy: 0.9651
Epoch 22/25
20/20 [==============================] - 50s 2s/step - loss: 0.2075 - accuracy: 0.9317 - val_loss: 0.0403 - val_accuracy: 0.9839
Epoch 23/25
20/20 [==============================] - 51s 3s/step - loss: 0.0937 - accuracy: 0.9675 - val_loss: 0.7548 - val_accuracy: 0.7070
Epoch 24/25
20/20 [==============================] - 46s 2s/step - loss: 0.0861 - accuracy: 0.9694 - val_loss: 0.1306 - val_accuracy: 0.9435
Epoch 25/25
20/20 [==============================] - 47s 2s/step - loss: 0.1002 - accuracy: 0.9655 - val_loss: 0.0382 - val_accuracy: 0.9839
ImageDataGenerator是Keras中图像预处理的类,经过预处理使得后续的训练更加准确。
Sequential定义了顺序(逐层堆叠)的神经网络模型,封装了神经网络的结构,有一组输入和一组输出。可以定义多个神经层,各层之间按照先后顺序堆叠,前一层的输出就是后一层的输入,通过多个层的堆叠,构建出神经网络。
卷积神经网络有两个常用的操作:卷积和池化。由于图片中可能包含干扰或者弱信息,使用卷积处理(此处的Conv2D层)使得我们能够找到特定的局部图像特征(如边缘)。此处使用了3×3的卷积核(滤波器),其权重是在训练过程中学习得到的;经典的手工设计3×3滤波器(例如用于边缘检测的垂直索伯滤波器)可以帮助理解卷积的作用。而池化(此处的MaxPooling2D)的作用是降采样,因为卷积层输出中包含很多冗余信息。池化通过减小输入的大小降低输出值的数量。详细的信息可以参考知乎回答“如何理解卷积神经网络(CNN)中的卷积和池化?”。更多的卷积算法参考Github Convolution arithmetic。
Dense的操作即全连接层操作,本质就是由一个特征空间线性变换到另一个特征空间。Dense层的目的是将前面提取的特征经过非线性变换,提取这些特征之间的关联,最后映射到输出空间上。这里第一个Dense层(512个神经元,relu激活)作为隐藏层,最后一个Dense层(3个神经元,softmax激活)作为输出层,对应石头、剪刀、布三个类别。
完成模型训练之后,我们绘制训练和验证结果的相关信息。
import matplotlib.pyplot as plt
# Per-epoch metrics recorded by model.fit().
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Figure 1: accuracy on the training and validation data.
plt.figure()
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)

# Figure 2: loss curves. The second figure used to be created without any
# content, which rendered an empty "Figure ... with 0 Axes".
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend(loc=0)

plt.show()

<Figure size 640x480 with 0 Axes>
利用生成的模型,我们可以运行实际中的例子,例如上传石头剪刀布的图片,使用model.predict进行推测。这里不做展开,后续我们利用TensorFlow Lite进行Android APP开发时,可以很自然地利用手机自带的摄像头或者图片库进行图片输入。
参考文献
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)