ACGAN(Auxiliary Classifier GAN)详解与实现(TensorFlow 2.x)
ACGAN原理
ACGAN的原理与条件GAN(Conditional GAN,CGAN)相似。对于CGAN和ACGAN,生成器的输入均为潜在向量及其标签,输出是属于输入类别标签的伪造图像。对于CGAN,判别器的输入是图像(假的或真实的图像)及其标签,输出是该图像属于真实图像的概率。对于ACGAN,判别器的输入只有图像,输出则是该图像属于真实图像的概率以及其类别概率。
本质上,在CGAN中,向网络提供了标签。在ACGAN中,使用辅助解码器网络重建辅助信息。ACGAN理论认为,强制网络执行其他任务可以提高原始任务的性能。在这种情况下,辅助任务是图像分类。原始任务是生成伪造图像。
判别器目标函数:
$$
\mathcal{L}^{(D)} = -\mathbb{E}_{x\sim p_{data}}\log D(x) - \mathbb{E}_{z}\log\left[1 - D(G(z|y))\right] - \mathbb{E}_{x\sim p_{data}}\log p(c|x) - \mathbb{E}_{z}\log p(c|G(z|y))
$$
其中 $c$ 为图像的类别:对真实图像 $x$ 是其真实标签,对生成图像 $G(z|y)$ 是条件标签 $y$。前两项是真伪判别损失,后两项是辅助分类器的分类损失。
生成器目标函数:
$$
\mathcal{L}^{(G)} = -\mathbb{E}_{z}\log D(G(z|y)) - \mathbb{E}_{z}\log p(c|G(z|y))
$$
ACGAN实现
模块导入
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import math
from PIL import Image
生成器
def generator(inputs, image_size, activation='sigmoid', labels=None):
    """Build the generator network.

    Arguments:
        inputs (layer): latent-vector input layer (z).
        image_size (int): target side length of the square output image;
            the output side is 4 * (image_size // 4).
        activation (string): activation applied to the output layer;
            ``None`` leaves the last transposed convolution linear.
        labels (tensor): optional one-hot label input. When given (ACGAN),
            it is concatenated with ``inputs`` and becomes a second model
            input. When ``None``, the generator is driven by ``inputs`` only.

    Returns:
        Model: the generator network.
    """
    # Two stride-2 transposed convolutions upsample image_size // 4 -> ~image_size.
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]

    if labels is not None:
        # Class-conditional: concatenate z with the one-hot label.
        inputs = [inputs, labels]
        x = keras.layers.concatenate(inputs, axis=1)
    else:
        # Unconditional: z is the only input (no None passed to concatenate).
        x = inputs

    x = keras.layers.Dense(image_resize * image_resize * layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize, image_resize, layer_filters[0]))(x)

    for filters in layer_filters:
        # Only the layers wider than layer_filters[-2] (128, 64) upsample.
        strides = 2 if filters > layer_filters[-2] else 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                                         kernel_size=kernel_size,
                                         strides=strides,
                                         padding='same')(x)

    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs, x, name='generator')
鉴别器
def discriminator(inputs, activation='sigmoid', num_labels=None):
    """Build the discriminator network.

    Arguments:
        inputs (Layer): image input layer.
        activation (string): activation applied to the real/fake output;
            ``None`` leaves it linear.
        num_labels (int): number of classes. When set (ACGAN), the model gets
            a second, softmax output named ``'label'`` with class probabilities.

    Returns:
        Model: the discriminator network. Its outputs are
        ``[source, label]`` when ``num_labels`` is set, otherwise ``source``.
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    x = inputs
    for filters in layer_filters:
        # Downsample with stride 2 on every layer except the last one.
        strides = 1 if filters == layer_filters[-1] else 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same')(x)
    x = keras.layers.Flatten()(x)

    # Source head: real/fake score (a probability with the default sigmoid).
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        outputs = keras.layers.Activation(activation)(outputs)

    if num_labels:
        # ACGAN auxiliary head: class probabilities for the input image.
        layer = keras.layers.Dense(layer_filters[-2])(x)
        labels = keras.layers.Dense(num_labels)(layer)
        labels = keras.layers.Activation('softmax', name='label')(labels)
        outputs = [outputs, labels]
    return keras.Model(inputs, outputs, name='discriminator')
模型构建
def build_and_train_models():
    """Build the ACGAN networks on MNIST and train them.

    Loads MNIST, then builds and compiles three models: the discriminator,
    the generator, and the stacked adversarial model (generator followed by
    a frozen discriminator). All of them are passed to ``train``.
    """
    def make_optimizer(learning_rate, decay_rate):
        # lr / (1 + decay * step): same schedule as the legacy ``decay``
        # argument, which newer TF 2.x Keras optimizers no longer accept.
        schedule = keras.optimizers.schedules.InverseTimeDecay(
            learning_rate, decay_steps=1, decay_rate=decay_rate)
        return keras.optimizers.RMSprop(learning_rate=schedule)

    # Load data, scale pixels to [0, 1], one-hot encode the labels.
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)

    # Hyper-parameters
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels,)

    # Discriminator. The local names differ from the module-level builder
    # functions: ``discriminator = discriminator(...)`` would make the name
    # local to this function and raise UnboundLocalError.
    inputs = keras.layers.Input(shape=input_shape, name='discriminator_input')
    disc_model = discriminator(inputs, num_labels=num_labels)
    optimizer = make_optimizer(lr, decay)
    # One loss per discriminator output: real/fake and class.
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    disc_model.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    disc_model.summary()

    # Generator
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape, name='z_input')
    labels = keras.layers.Input(shape=label_shape, name='labels')
    gen_model = generator(inputs, image_size, labels=labels)
    gen_model.summary()

    # Adversarial model: the discriminator is frozen here, so training this
    # model updates only the generator. The discriminator was compiled above
    # while still trainable.
    optimizer = make_optimizer(lr * 0.5, decay * 0.5)
    disc_model.trainable = False
    adversarial = keras.Model([inputs, labels],
                              disc_model(gen_model([inputs, labels])),
                              name=model_name)
    adversarial.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    adversarial.summary()

    models = (gen_model, disc_model, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
模型训练
def train(models, data, params):
    """Train the discriminator and adversarial networks.

    Each step does two things:
      1. Trains the discriminator on a batch of real images followed by an
         equal number of generated images. The targets are real/fake plus
         the class labels.
      2. Trains the generator through the adversarial model, with fresh fake
         images targeted as real and with their sampled class labels.

    Every ``save_interval`` steps, 16 samples are plotted from a fixed
    noise/label set. The generator is saved after the last step.

    Arguments:
        models (list): generator, discriminator, adversarial
        data (list): x_train, y_train (one-hot labels)
        params (list): batch_size, latent_size, train_steps, num_labels,
            model_name
    """
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params
    save_interval = 500
    # Fixed noise and labels (cycling through the classes) so the saved
    # sample grids are comparable across steps.
    noise_input = np.random.uniform(-1., 1., size=[16, latent_size])
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name, 'Labels for generated images: ',
          np.argmax(noise_label, axis=1))

    # Loop-invariant targets: real (1) then fake (0) for the discriminator;
    # all "real" for the adversarial step.
    y_disc = np.ones([2 * batch_size, 1])
    y_disc[batch_size:, :] = 0.0
    y_adv = np.ones([batch_size, 1])

    for i in range(train_steps):
        # Discriminator step: a real batch plus a generated batch.
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        # verbose=0 keeps predict from printing a progress bar every step.
        fake_images = generator.predict([noise, fake_labels], verbose=0)
        x = np.concatenate((real_images, fake_images))
        labels = np.concatenate((real_labels, fake_labels))
        # metrics: [total loss, source loss, label loss, source acc, label acc]
        metrics = discriminator.train_on_batch(x, [y_disc, labels])
        log = ('%d: [disc loss: %f, srcloss: %f, lblloss: %f, '
               'srcacc: %f, lblacc: %f]'
               % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4]))

        # Generator step through the adversarial model (discriminator frozen).
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        metrics = adversarial.train_on_batch([noise, fake_labels],
                                             [y_adv, fake_labels])
        log += (' [advr loss: %f, srcloss: %f, lblloss: %f, '
                'srcacc: %f, lblacc: %f]'
                % (metrics[0], metrics[1], metrics[2], metrics[3], metrics[4]))
        print(log)

        if (i + 1) % save_interval == 0:
            # Plot generated samples from the fixed noise/label set.
            plot_images(generator,
                        noise_input=noise_input,
                        noise_label=noise_label,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)

    generator.save(model_name + ".h5")
虚假图像生成及绘制plot_images函数
def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate fake images and save them as a grid.

    The PNG is written to ``<model_name>/<step:05d>.png``.

    # Arguments
        generator (Model): generator model
        noise_input (ndarray): latent vectors, one row per image
        noise_label (ndarray): optional one-hot labels, fed as a second input
        noise_codes (list): optional extra model inputs, appended after the
            latent vectors (and labels, if given)
        show (bool): show the figure instead of closing it
        step (int): training step, used in the file name
        model_name (string): output directory name
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)

    # Collect model inputs as a list so that labels/codes are separate
    # inputs, never added element-wise to the noise array.
    model_inputs = [noise_input]
    if noise_label is not None:
        model_inputs.append(noise_label)
    if noise_codes is not None:
        model_inputs.extend(noise_codes)
    if len(model_inputs) == 1:
        model_inputs = noise_input

    images = generator.predict(model_inputs, verbose=0)
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Near-square grid that fits every image even when the count is not a
    # perfect square (a rows x rows grid would run out of subplot slots).
    rows = max(1, int(math.sqrt(num_images)))
    cols = max(1, math.ceil(num_images / rows))

    plt.figure(figsize=(2.2, 2.2))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')
训练结果
# Script entry point: build the ACGAN networks and run training.
if __name__ == "__main__":
    build_and_train_models()
step=1000:
step=15000:
更多推荐
所有评论(0)