基于单词字符的卷积文本识别

!python --version
Python 3.10.20
import numpy as np
print(np.__version__)
1.23.5
import torch
print(torch.__version__)
`2.2.2

完整代码

from sklearn.model_selection import train_test_split   # pip3 install scikit-learn
import csv
import numpy as np
import re
import torch
import einops.layers.torch as elt

def text_clearTitle(text):
    text = text.lower()                     # 将文本转换成小写
    text = re.sub(r"[^a-z]"," ",text)    # 替换非标准字符,^是求反操作
    text = re.sub(r" +"," ",text)           # 替换多重空格
    text = text.strip()                     # 取出首尾空格
    text = text + " eos"
    return text

def get_one_hot(list,alphabet_title = None):
    if alphabet_title == None:    # 设置字符集
       alphabet_title = "abcdefghigklmnopqrstuvwxyz "
    else:
        alphabet_title = alphabet_title
    values = np.array(list)             #获取字符数列
    n_values = len(alphabet_title)  #获取字符表长度
    return np.eye(n_values)[values]

def get_label_one_hot(list):
    values = np.array(list)
    n_values = np.max(values) + 1
    return np.eye(n_values)[values]
    
def get_char_list(string):
    alphabet_title = "abcdefghijklmnopqrstuvwxyz "
    char_list = []
    for char in string:
        num = alphabet_title.index(char)
        char_list.append(num)
    return char_list

def get_one_hot(list,alphabet_title = None):
    if alphabet_title == None:    # 设置字符集
       alphabet_title = "abcdefghijklmnopqrstuvwxyz "
    else:
        alphabet_title = alphabet_title
    values = np.array(list)             #获取字符数列
    n_values = len(alphabet_title) + 1  #获取字符表长度
    return np.eye(n_values)[values]

def get_string_matrix(string):
    char_list = get_char_list(string)
    string_matrix = get_one_hot(char_list)
    return string_matrix

def get_handle_string_matrix(string, n=64):
    string_length = len(string)
    #print(string_length)
    if string_length > 64:
        string = string[:64]
        string_matrix = get_string_matrix(string)
        return string_matrix
    else:
        string_matrix = get_string_matrix(string)
        #print("字符长度为=",string_length,"title拆成数组为:",string_matrix.shape)
        handle_length = n - string_length
        #print("handle_length=",handle_length)
        pad_matrix = np.zeros([handle_length,28])
        string_matrix = np.concatenate([string_matrix,pad_matrix])
        return string_matrix

def char_CNN(input_dim = 28):    
     # torch.nn.Sequential 是 PyTorch 中用于快速构建顺序模型的容器
    model = torch.nn.Sequential(
        # 第一层卷积
        elt.Rearrange("b l c -> b c l"),
        torch.nn.Conv1d(input_dim,32,kernel_size=3,padding=1),
        elt.Rearrange("b c l -> b l c"),
        torch.nn.ReLU(),
        torch.nn.LayerNorm(32),

        #第二层卷积
        elt.Rearrange("b l c -> b c l"),
        torch.nn.Conv1d(32, 28, kernel_size=3, padding=1),
        elt.Rearrange("b c l -> b l c"),
        torch.nn.ReLU(),
        torch.nn.LayerNorm(28),

        #flatten
        torch.nn.Flatten(),  #[batch_size,64 * 28]
        torch.nn.Linear(64 * 28,64),
        torch.nn.ReLU(),

        torch.nn.Linear(64,5),
        torch.nn.Softmax()   
    )
    return model


def get_dataset():
    agnews_label = []
    agnews_title = []
    agnews_train = csv.reader(open("./dataset/train.csv","r"))
    for line in agnews_train:
        agnews_label.append(np.int32(line[0]))
        agnews_title.append(text_clearTitle(line[1]))
    train_dataset = []
    for title in agnews_title:
        string_matrix = get_handle_string_matrix(title)
        train_dataset.append(string_matrix)
    
    train_dataset = np.array(train_dataset,float)
    label_dataset = get_label_one_hot(agnews_label)

    return train_dataset,label_dataset

train_dataset,label_dataset = get_dataset()
X_train,X_test, y_train, y_test = train_test_split(train_dataset,label_dataset,test_size=0.1, random_state=828)  #将数据集划分为训练集和测试集

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)
#获取device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = char_CNN().to(device)

# 定义交叉熵损失函数
def cross_entropy(pred, label):
    res = -torch.sum(label * torch.log(pred)) / label.shape[0]
    return torch.mean(res)
    
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

batch_size = 128
train_num = len(X_test)//128
for epoch in range(99):
    train_loss = 0.
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        x_batch = torch.tensor(X_train[start:end]).type(torch.float32).to(device)
        y_batch = torch.tensor(y_train[start:end]).type(torch.float32).to(device)

        pred = model(x_batch)
        loss = cross_entropy(pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == y_batch.argmax(1)).type(torch.float32).sum().item() / batch_size
    print("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))
(108000, 64, 28)
(12000, 64, 28)
(108000, 5)
(12000, 5)
epoch: 0 train_loss: 1.4 accuracy: 0.37
epoch: 1 train_loss: 1.32 accuracy: 0.39
epoch: 2 train_loss: 1.27 accuracy: 0.45
...
epoch: 96 train_loss: 0.02 accuracy: 1.0
epoch: 97 train_loss: 0.02 accuracy: 1.0
epoch: 98 train_loss: 0.02 accuracy: 1.0

数据读取处理

文本主题提取:基于TF-IDF博文中 详细介绍了对Ag-news数据集的读取、清洗等操作步骤

文本的One-Hot

1. 将字符串按字符表的顺序转换成数字序列

def get_char_list(string):
    alphabet_title = "abcdefghijklmnopqrstuvwxyz "
    char_list = []
    for char in string:
        if char in alphabet_title:
            num = alphabet_title.index(char)
            char_list.append(num)
    return char_list

print(get_char_list("hello"))
char_list = get_char_list("wall st bears claw back into the black reuterseos")
print(char_list)
[7, 4, 11, 11, 14]
[22, 0, 11, 11, 26, 18, 19, 26, 1, 4, 0, 17, 18, 26, 2, 11, 0, 22, 26, 1, 0, 2, 10, 26, 8, 13, 19, 14, 26, 19, 7, 4, 26, 1, 11, 0, 2, 10, 26, 17, 4, 20, 19, 4, 17, 18, 4, 14, 18]

2. 文本的One-Hot 处理

针对不同的字符获取字符表对应位置进行提取,根据提取的位置将对应的字符位置设置成1,其他为0.

import numpy as np
def get_one_hot(list,alphabet_title = None):
    if alphabet_title == None:    # 设置字符集
       alphabet_title = "abcdefghigklmnopqrstuvwxyz "
    else:
        alphabet_title = alphabet_title
    values = np.array(list)             #获取字符数列
    n_values = len(alphabet_title)  #获取字符表长度
    return np.eye(n_values)[values]

char_list = get_char_list("hello")
string_matrix = get_one_hot(char_list)

print(string_matrix)
print(string_matrix.shape)
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0.]]
(5, 27)

[ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ] \begin{bmatrix} 0.& 0.& 0.& 0.& 0.& 0.& 0.& 1.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.\\ 0.& 0.& 0.& 0.& 1.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.\\ 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 1.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.\\ 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 1.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.\\ 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 1.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0.& 0. \end{bmatrix} 0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.

np.eye()

numpy.eye(N, M=None, k=0, dtype=<class 'float'>, order='C', *, like=None)

N:生成的矩阵的行数。
M:生成的矩阵的列数(可选,默认为 N)。
k:对角线的索引(可选,默认为0,即主对角线,k>0 为上对角线,k<0 为下对角线)。
dtype:数组的数据类型(可选,默认为 float)

生成对角方阵
np.eye(3)
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])
np.eye(3,k=1) # k=1表示 1的偏移量
array([[0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 0.]])
生成M*N的对角矩阵
np.eye(3,5)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.]])
np.eye(3,5,k=1)
array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]])
将索引数组转换为 one-hot 编码数组
# 将索引数组转换为 one-hot 编码数组 
# 0作为索引起点
values = [1, 0, 3]
n_values = np.max(values) + 1
np.eye(n_values)[values]
array([[0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.]])

3. 生成One-Hot矩阵

import csv
import re
agnews_train = csv.reader(open("./dataset/train.csv","r"))
agnews_title = []

def text_clearTitle(text):
    text = text.lower()                     # 将文本转换成小写
    text = re.sub(r"[^a-z]"," ",text)    # 替换非标准字符,^是求反操作
    text = re.sub(r" +"," ",text)           # 替换多重空格
    text = text.strip()                     # 取出首尾空格
    text = text + " eos"
    return text

def get_char_list(string):
    alphabet_title = "abcdefghijklmnopqrstuvwxyz "
    char_list = []
    for char in string:
        num = alphabet_title.index(char)
        char_list.append(num)
    return char_list

def get_one_hot(list,alphabet_title = None):
    if alphabet_title == None:    # 设置字符集
       alphabet_title = "abcdefghijklmnopqrstuvwxyz "
    else:
        alphabet_title = alphabet_title
    values = np.array(list)             #获取字符数列
    n_values = len(alphabet_title) + 1  #获取字符表长度
    return np.eye(n_values)[values]

def get_string_matrix(string):
    char_list = get_char_list(string)
    string_matrix = get_one_hot(char_list)
    return string_matrix

for line in agnews_train:
    agnews_title.append(text_clearTitle(line[1]))
    
for title in agnews_title[17:19]:
    print(title)
    string_matrix = get_string_matrix(title)
    print(string_matrix.shape)
in a down market head toward value funds eos
(44, 28)
us trade deficit swells in june eos
(35, 28)

其中28表示:26个字符,一个空格,一个label位

4. 矩阵补全

对于不同长度的矩阵,进行规范化处理,即长的截短,短的补长。

def get_handle_string_matrix(string, n=64):
    string_length = len(string)
    #print(string_length)
    if string_length > 64:
        string = string[:64]
        string_matrix = get_string_matrix(string)
        return string_matrix
    else:
        string_matrix = get_string_matrix(string)
        #print("字符长度为=",string_length,"title拆成数组为:",string_matrix.shape)
        handle_length = n - string_length
        #print("handle_length=",handle_length)
        pad_matrix = np.zeros([handle_length,28])
        string_matrix = np.concatenate([string_matrix,pad_matrix])
        return string_matrix
        
for title in agnews_title[17:19]:
    print(title)
    string_matrix = get_handle_string_matrix(title)
    print(string_matrix.shape)
    print(type(string_matrix))
    print(string_matrix)
in a down market head toward value funds eos
(64, 28)
<class 'numpy.ndarray'>
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
us trade deficit swells in june eos
(64, 28)
<class 'numpy.ndarray'>
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

np.concatenate() 合并矩阵

numpy.concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting="same_kind")

核心功能:沿指定轴连接一个或多个数组,生成新数组
核心参数:
(a1, a2, …):必填,待连接的数组序列(需为相同数据类型或可自动转换)
axis:可选,指定连接轴(0为默认值,代表按垂直上下方向拼接)
dtype:可选,指定输出数组的数据类型(不指定时自动推断)

一维矩阵(序列水平左右拼接)
# 一维矩阵 按行拼接
import numpy as np
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c = np.concatenate((a, b),axis=0)
print(c)         # 输出:[1 2 3 4 5 6]
print(c.shape)   # 输出:(6,)
[1 2 3 4 5 6]
(6,)
二维矩阵 垂直上下拼接 axis=0

[ 1 2 3 4 ] \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} [1324] [ 5 6 7 8 ] \begin{bmatrix} 5 & 6\\ 7 & 8 \end{bmatrix} [5768] 垂直上下拼接

[ 1 2 3 4 5 6 7 8 ] \begin{bmatrix} 1 & 2\\ 3 & 4\\ 5 & 6\\ 7 & 8\\ \end{bmatrix} 13572468

arr1 = np.array([[1, 2], [3, 4]])
arr2 = np.array([[5, 6], [7, 8]])
row_concat = np.concatenate((arr1, arr2), axis=0)
print(row_concat)
print(row_concat.shape)
[[1 2]
 [3 4]
 [5 6]
 [7 8]]
(4, 2)
二维矩阵 水平左右拼接 axis=1

[ 1 2 3 4 ] \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} [1324] [ 5 6 7 8 ] \begin{bmatrix} 5 & 6\\ 7 & 8 \end{bmatrix} [5768] 水平左右拼接

[ 1 2 5 6 3 4 7 8 ] \begin{bmatrix} 1 & 2 &5 & 6 \\ 3 & 4 &7 & 8\\ \end{bmatrix} [13245768]

arr1 = np.array([[1, 2], [3, 4]])
arr2 = np.array([[5, 6], [7, 8]])
row_concat = np.concatenate((arr1, arr2), axis=1)
print(row_concat)
print(row_concat.shape)
[[1 2 5 6]
 [3 4 7 8]]
(2, 4)
三维矩阵 按深度拼接 axis=0
#假设处理视频数据(形状为(帧数, 高度, 宽度)),需按帧数合并:
video1 = np.random.rand(2, 100, 100)   # 2帧100x100图像
video2 = np.random.rand(3, 100, 100)   # 3帧同尺寸图像
merged_video = np.concatenate((video1, video2), axis=0)
print(merged_video.shape)  # 输出:(5, 100, 100)(帧数合并,高度宽度不变)
(5, 100, 100)

参数详解

sklearn.model_selection.train_test_split()

X_train,X_test,y_train,y_test=sklearn.model_selection.train_test_split(
                                    train_data,
                                    train_target,
                                    test_size=0.4,
                                    random_state=0,
                                    stratify=y_train)
  • train_data:所要划分的样本特征集
  • train_target:所要划分的样本结果
  • test_size:样本占比,如果是整数的话就是样本的数量
  • random_state:是随机数的种子。

    随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。
  • stratify是为了保持split前类的分布。比如有100个数据,80个属于A类,20个属于B类。

    如果train_test_split(… test_size=0.25, stratify = y_all), 那么split之后数据如下:

    training: 75个数据,其中60个属于A类,15个属于B类。

    testing: 25个数据,其中20个属于A类,5个属于B类。

    用了stratify参数,training集和testing集的类的比例是 A:B= 4:1,等同于split前的比例(80:20)。

    通常在这种类分布不平衡的情况下会用到stratify。
    将stratify=X就是按照X中的比例分配
    将stratify=y就是按照y中的比例分配

elt.Rearrange()

elt.Rearrange("b l c -> b c l") 是使用 ‌einops‌ 库中的 Rearrange 操作对张量维度进行重排,具体含义如下:

功能说明
‌输入张量形状‌:(b, l, c)

  • b:batch size(批次大小)
  • l:序列长度或 token 数量(如语言模型中的词元数,或时间步)
  • c:通道数或特征维度(如 embedding 维度)

输出张量形状‌:(b, c, l)

  • 将第2维(l)和第3维(c)的位置互换。
  • 这相当于对最后两个维度进行 ‌转置‌(transpose),但以声明式方式表达,更清晰、不易出错。

torch.nn.Conv1d()

torch.nn.Conv1d() 是 PyTorch 中用于‌一维卷积操作‌的核心模块,广泛应用于处理‌序列数据‌,如文本、音频、时间序列等。

核心参数说明

  • in_channels‌:输入信号的通道数(如音频声道数、词向量维度
  • out_channels‌:输出通道数,即卷积核的数量。
  • kernel_size‌:卷积核的大小(整数或元组)。
  • stride‌:卷积步长,默认为 1。
  • padding‌:输入两侧填充的长度,默认为 0。
  • dilation‌:空洞卷积的膨胀系数,默认为 1。
  • groups‌:分组卷积设置,默认为 1(标准卷积)。
  • bias‌:是否添加偏置项,默认为 True。
  • padding_mode‌:填充模式,默认为 ‘zeros’。

输入与输出形状

  • 输入形状‌:(batch_size, in_channels, length)
  • 输出形状‌:(batch_size, out_channels, L_out)
  • 其中,输出长度 L_out 的计算公式为:

torch.nn.ReLU()

torch.nn.ReLU() 是 PyTorch 中最常用的激活函数之一,它对输入张量逐元素执行修正线性单元ReLU(Rectified Linear Unit)运算

torch.nn.LayerNorm()

torch.nn.LayerNorm 是 PyTorch 中的层归一化模块。

与批归一化不同,层归一化在单个样本的特征维度上进行归一化,不依赖 batch size。

参数说明:

  • normalized_shape (int 或 list): 需要归一化的维度。
  • eps (float): 数值稳定性的epsilon。默认为 1e-5。
  • elementwise_affine (bool): 是否使用可学习的缩放和偏移。默认为 True。

torch.nn.Flatten()

torch.nn.Flatten 是 PyTorch 中的张量展平模块。
它将多维张量展平为一维,常用于卷积层和全连接层之间的连接。

torch.nn.Flatten(start_dim=1, end_dim=-1)
参数说明:

  • start_dim (int): 展平开始的维度。默认为 1(保留 batch 维度)
  • end_dim (int): 展平结束的维度。默认为 -1(到最后一维)

torch.nn.Linear()

torch.nn.Linear 是 PyTorch 中用于创建全连接层(也称为线性层或仿射变换)的模块。
它是神经网络中最基础也是最常用的层之一,负责将输入特征线性变换到输出特征空间。

函数定义
torch.nn.Linear(in_features, out_features, bias=True)

参数说明:

  • in_features (int): 输入特征的维度,即上一层输出的特征数。
  • out_features (int): 输出特征的维度,即本层输出的特征数。
  • bias (bool): 是否添加偏置项。默认为 True。如果设置为 False,则该层不会学习偏置参数。

属性:

  • weight (Tensor): 形状为 (out_features, in_features) 的可学习权重矩阵
  • bias (Tensor): 形状为 (out_features,) 的可学习偏置向量。如果 bias=False,则不存在此属性

torch.nn.Softmax()

torch.nn.Softmax 是 PyTorch 中的 Softmax 激活函数。
它将输入转换为概率分布,所有输出之和为 1。

函数定义

torch.nn.Softmax(dim=None)

参数:

  • dim: 进行 Softmax 的维度

torch.optim 优化器模块

优化器是深度学习中的核心组件,负责根据损失函数的梯度调整模型参数,使模型能够逐步逼近最优解。
在 PyTorch 中,torch.optim 模块提供了多种优化算法的实现,是训练神经网络不可或缺的工具。

为什么需要优化器

优化器在深度学习中扮演着至关重要的角色,它解决了手动更新参数的繁琐问题。

  • 自动化参数更新:手动计算和更新每个参数非常繁琐,优化器自动完成这一工作
  • 加速收敛:使用优化算法比普通梯度下降更快找到最优解
  • 避免局部最优:某些优化器具有跳出局部最优的能力

常见优化器类型

不同优化器适用于不同场景,选择合适的优化器可以显著提升训练效果。

优化器名称 主要特点 适用场景
SGD 简单基础,可带动量 基础教学、简单模型、CNN
Adam 自适应学习率 大多数深度学习任务
AdamW Adam + 权重衰减分离 需要 L2 正则化的任务
RMSprop 自适应学习率 RNN 网络、语音识别
Adagrad 参数独立学习率 稀疏数据、文本处理
Adadelta 自适应学习率 长期训练任务

优化器核心 API

掌握优化器的基本使用流程是深度学习的第一步。

基本使用流程

优化器的使用遵循固定模式:创建实例 → 清空梯度 → 反向传播 → 更新参数。

关键方法说明

优化器提供了几个核心方法来管理参数更新过程。

  • zero_grad(set_to_none=True):清空参数的梯度缓存。设置为 True 时会将梯度设为 None,比设为 0 更节省显存
  • step():执行单次参数更新,根据梯度和学习率更新模型参数
  • state_dict():获取优化器状态字典,可用于保存检查点
  • load_state_dict(state_dict):加载优化器状态,用于恢复训练
  • add_param_group(param_group):动态添加参数组

注意:必须在每次反向传播前调用 zero_grad(),否则梯度会累积,导致训练不稳定。建议使用 zero_grad(set_to_none=True) 以节省显存

torch.optim.Adam()

Adam 是目前最常用的优化器之一,结合了动量和自适应学习率的优点。它通过计算梯度的一阶和二阶矩估计来自适应调整每个参数的学习率。

optimizer = optim.Adam(
    params=model.parameters(),
    lr=0.001,                      # 推荐使用较小的学习率
    betas=(0.9, 0.999),            # 常用的动量参数
    eps=1e-8,                      # 数值稳定项
    weight_decay=1e-4,             # L2 正则化
    amsgrad=False                  # 是否使用 AMSGrad
)
# Adam 优化器参数说明
# params: 要优化的参数
# lr: 学习率,默认 0.001(推荐值)
# betas: 用于计算梯度和梯度平方的指数移动平均系数 (beta1, beta2) 
#         beta1 控制一阶矩估计(动量),默认 0.9
#         beta2 控制二阶矩估计(方差),默认 0.999
# eps: 数值稳定项,防止除零错误,默认 1e-8
# weight_decay: L2 正则化系数,默认 0
# amsgrad: 是否使用 AMSGrad 变体,默认 False

核心参数说明:

  • betas (Tuple[float, float]):控制梯度和梯度平方的指数移动平均
  • eps (float):数值稳定项,防止分母为零
  • amsgrad (bool):是否使用 AMSGrad 变体,使用后可保证收敛性

特点:

  • 自适应学习率:根据参数的历史梯度自动调整学习率
  • 结合动量概念:利用一阶矩估计加速收敛
  • 鲁棒性强:对超参数选择相对不敏感
  • 收敛速度快,适合快速原型开发

Adam 是大多数深度学习任务的默认选择,但在某些特定场景(如 GAN、强化学习)下可能需要尝试其他优化器。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐