TensorFlow 模型剪枝

一、运行环境

TensorFlow-gpu 1.x(需要包含 tf.contrib.model_pruning 模块;TensorFlow 2.x 已移除 tf.contrib)

python 3.6

numpy 1.16

CUDA 9.0, cuDNN 7.0

 

二、模型剪枝方法

模型训练时剪枝,只需选定需要剪枝的层,对于选中做剪枝的层增加一个二进制掩模(mask)变量,形状和该层的权值张量形状完全相同。该掩模决定了哪些权值参与前向计算。掩模更新算法则需要为 TensorFlow 训练计算图注入特殊运算符,对当前层权值按绝对值大小排序,对幅度小于一定门限的权值将其对应掩模值设为 0。反向传播梯度也经过掩模,被屏蔽的权值(mask 为 0)在反向传播步骤中无法获得更新量。

该方法已在 TensorFlow 的 tensorflow.contrib.model_pruning.python.pruning 模块中实现。

选定的层需要替换为 tensorflow.contrib.model_pruning.python.layers.layers 中对应的带掩模层(如 masked_fully_connected、masked_conv2d)。

三、模型剪枝具体实现

如果是在全连接层做剪枝,全连接层代码写成

from tensorflow.contrib.model_pruning.python.layers import layers

# Hidden layers keep masked_fully_connected's default ReLU activation.
fc_layer1 = layers.masked_fully_connected(ft, 200)

fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)

# The output layer produces logits for softmax cross-entropy, so it must be
# linear: masked_fully_connected applies ReLU by default, which would clip
# every negative logit to 0.
prediction = layers.masked_fully_connected(fc_layer2, 10, activation_fn=None)

这里相当于把最后三层全连接层替换为带掩模(mask)的版本。注意:masked_fully_connected 默认使用 ReLU 激活函数,作为 logits 输出的最后一层应传入 activation_fn=None。

 

如果是卷积层做剪枝,卷积层代码写成

from tensorflow.contrib.model_pruning.python.layers import layers

# masked_conv2d follows tf.contrib.layers.conv2d: pass the number of output
# channels as num_outputs and only the spatial size as kernel_size; the input
# channel count is inferred from `indata`.
conv = layers.masked_conv2d(indata, num_outputs=outchannel, kernel_size=[5, 5],
                            padding='SAME', activation_fn=tf.nn.relu)

加入带掩模的层之后,还需要配置剪枝方式以及目标稀疏度:

# Get, print, and edit the pruning hyperparameters.
# The pruning schedule is driven by the global step, which must exist before
# the Pruning object is created (the optimizer's minimize() increments it).
global_step = tf.train.get_or_create_global_step()

pruning_hparams = pruning.get_pruning_hparams()

print("Pruning Hyper parameters:", pruning_hparams)

# Change hyperparameters to meet our needs
pruning_hparams.begin_pruning_step = 0            # first step at which masks are updated
pruning_hparams.end_pruning_step = 250            # last step at which masks are updated
pruning_hparams.pruning_frequency = 1             # update the masks every step
pruning_hparams.sparsity_function_end_step = 250  # step by which target_sparsity is reached
pruning_hparams.target_sparsity = .9              # final fraction of weights set to zero

# Create the pruning object from these hyperparameters and the step counter.
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Op that updates the masks only when the schedule says so; run it every step.
prune_op = p.conditional_mask_update_op()

这里把目标稀疏度(target_sparsity)设为 0.9,即训练过程中逐步把 90% 的权值置零,并在第 250 步时达到该目标。

 

之后在模型训练时,每一步先运行剪枝操作,再运行训练操作:

def train_network(self, graph, x_train, y_train):
    """Run one pruning-aware training step on a single batch.

    First runs ``graph['prune_op']`` (the mask-update op), then one
    ``graph['optimize']`` step fed with ``x_train``/``y_train``.
    """
    session = self.sess
    # 1) mask update comes first, so the optimizer step sees the current masks
    session.run(graph['prune_op'])
    # 2) one optimizer step on this batch
    batch_feed = {graph['x']: x_train, graph['y']: y_train}
    session.run(graph['optimize'], feed_dict=batch_feed)

 

总的来说,操作步骤为:先选定需要剪枝的层并替换成相应的带掩模层,再配置剪枝参数,最后在训练时每一步先 run 剪枝操作,再 run 训练操作。

四、模型剪枝完整代码

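1. model_bulid 中 fc_layer1、fc_layer2、prediction 三行(masked_fully_connected)做全连接层剪枝

2. model_bulid 中从 pruning.get_pruning_hparams() 到 p.conditional_mask_update_op() 的部分配置剪枝参数

3. train_network 中先 sess.run 剪枝操作(prune_op),再运行训练操作;其余代码都是常规的 CNN 代码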

"""
!/usr/bin/env python
-*- coding:utf-8 -*-
Author: eric.lai
Created on 2019/5/27 10:50
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
import time

class LeNet_Mode():
    """Build, train and magnitude-prune a LeNet-style CNN with TensorFlow 1.x.

    Network structure built by ``model_bulid``:
        (conv 5x5 32, pool/2)
        (conv 5x5 64, pool/2)
        (masked fc 200) => (masked fc 100) => (masked fc classes)

    The fully connected layers come from ``tf.contrib.model_pruning`` and carry
    a binary mask that ``prune_op`` updates during training.
    """

    def conv_layer(self, data, ksize, stride, name, w_biases=False, padding="SAME"):
        """Apply a 2-D convolution, optionally adding a per-channel bias.

        Args:
            data: input tensor passed to ``tf.nn.conv2d``.
            ksize: filter shape ``[height, width, in_channels, out_channels]``.
            stride: strides list passed to ``tf.nn.conv2d``.
            name: variable scope name, also used as the filter variable name.
            w_biases: if True, add the zero-initialised bias of size ``ksize[3]``.
            padding: ``"SAME"`` or ``"VALID"``.

        Returns:
            The pre-activation convolution output.
        """
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(name=name, shape=ksize, initializer=w_init)
            # ``name=`` must be a keyword: the second positional argument of
            # tf.Variable is ``trainable``, so 'biases' was never used as a name.
            biases = tf.Variable(tf.constant(0.0, shape=[ksize[3]], dtype=tf.float32), name='biases')
        # Keyword is ``strides`` in both branches; the biased branch previously
        # passed ``stride=`` and raised a TypeError.
        cov = tf.nn.conv2d(input=data, filter=w, strides=stride, padding=padding)
        if w_biases:
            cov = cov + biases
        return cov

    def pool_layer(self, data, ksize, stride, name, padding='VALID'):
        """Max-pool ``data`` with the given window ``ksize`` and ``stride``."""
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            max_pool = tf.nn.max_pool(value=data, ksize=ksize, strides=stride, padding=padding)
        return max_pool

    def flatten(self, data):
        """Reshape a rank-4 tensor ``[N, H, W, C]`` to ``[N, H*W*C]``."""
        _, height, width, channels = data.get_shape().as_list()
        return tf.reshape(data, [-1, height * width * channels])

    def fc_layer(self, data, name, fc_dims):
        """Dense layer ``relu(data @ w + biases)`` with ``fc_dims`` outputs (unpruned)."""
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            data_shape = data.get_shape().as_list()
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(shape=[data_shape[1], fc_dims], name='w', initializer=w_init)
            biases = tf.Variable(tf.constant(0.0, shape=[fc_dims], dtype=tf.float32), name='biases')
            fc = tf.nn.relu(tf.matmul(data, w) + biases)
        return fc

    def finlaout_layer(self, data, name, fc_dims):
        """Linear output layer ``data @ w + biases`` (logits, unpruned)."""
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            data_shape = data.get_shape().as_list()
            w_init = tf.contrib.layers.xavier_initializer()
            w = tf.get_variable(shape=[data_shape[1], fc_dims], name='w', initializer=w_init)
            biases = tf.Variable(tf.constant(0.0, shape=[fc_dims], dtype=tf.float32), name='biases')
            fc = tf.matmul(data, w) + biases
        return fc

    def model_bulid(self, height, width, channel, classes):
        """Build the pruned LeNet graph and return handles to its tensors/ops.

        Args:
            height, width, channel: input image size; ``x`` is fed as
                ``[batch, height, width, channel]``.
            classes: number of output classes (width of the one-hot ``y``).

        Returns:
            dict with the placeholders ``x``/``y``, the training op
            ``optimize``, ``correct_prediction`` (the logits),
            ``correct_times_in_batch`` (number of correct predictions in the
            batch), ``cost`` (mean softmax cross-entropy), ``accurary`` (batch
            accuracy) and ``prune_op`` (conditional mask update).
        """
        x = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel])
        y = tf.placeholder(dtype=tf.float32, shape=[None, classes])

        # conv 1: (conv 5x5 32, pool/2); SAME conv keeps HxW, the pool halves it.
        conv1_1 = tf.nn.relu(self.conv_layer(x, ksize=[5, 5, channel, 32], stride=[1, 1, 1, 1],
                                             padding="SAME", name="conv1_1"))
        pool1_1 = self.pool_layer(conv1_1, ksize=[1, 2, 2, 1], stride=[1, 2, 2, 1], name="pool1_1")

        # conv 2: (conv 5x5 64, pool/2); e.g. a 28x28x1 input ends up 7x7x64 here.
        conv2_1 = tf.nn.relu(self.conv_layer(pool1_1, ksize=[5, 5, 32, 64], stride=[1, 1, 1, 1],
                                             padding="SAME", name="conv2_1"))
        pool2_1 = self.pool_layer(conv2_1, ksize=[1, 2, 2, 1], stride=[1, 2, 2, 1], name="pool2_1")

        ft = self.flatten(pool2_1)

        # Prunable dense layers: (fc 200) => (fc 100) => (fc classes).
        fc_layer1 = layers.masked_fully_connected(ft, 200)
        fc_layer2 = layers.masked_fully_connected(fc_layer1, 100)
        # The last layer emits logits for softmax cross-entropy, so it must be
        # linear (masked_fully_connected defaults to ReLU, which clips negative
        # logits); its width follows ``classes`` rather than a hard-coded 10.
        prediction = layers.masked_fully_connected(fc_layer2, classes, activation_fn=None)

        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))

        # Pruning: masks are updated on a schedule driven by global_step.
        global_step = tf.train.get_or_create_global_step()
        pruning_hparams = pruning.get_pruning_hparams()
        print("Pruning Hyper parameters:", pruning_hparams)
        pruning_hparams.begin_pruning_step = 0
        pruning_hparams.end_pruning_step = 250
        pruning_hparams.pruning_frequency = 1
        pruning_hparams.sparsity_function_end_step = 250
        pruning_hparams.target_sparsity = .9
        p = pruning.Pruning(pruning_hparams, global_step=global_step)
        prune_op = p.conditional_mask_update_op()

        # Optimizer with step-wise exponential learning-rate decay.
        LEARNING_RATE_BASE = 0.001
        LEARNING_RATE_DECAY = 0.9
        LEARNING_RATE_STEP = 300
        # Decay on the same global_step that minimize() increments; a separate,
        # never-incremented counter would leave the rate constant forever.
        learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                                   global_step,
                                                   LEARNING_RATE_STEP,
                                                   LEARNING_RATE_DECAY,
                                                   staircase=True)
        # AUTO_REUSE lets Adam create its slot variables under the current scope.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            optimize = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
                loss, global_step=global_step)

        # Evaluation tensors.
        prediction_label = prediction
        correct_prediction = tf.equal(tf.argmax(prediction_label, 1), tf.argmax(y, 1))
        accurary = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
        # A count, not a mean: reduce_mean over int32 truncates to 0 or 1.
        correct_times_in_batch = tf.reduce_sum(tf.cast(correct_prediction, dtype=tf.int32))

        return dict(
            x=x,
            y=y,
            optimize=optimize,
            correct_prediction=prediction_label,
            correct_times_in_batch=correct_times_in_batch,
            cost=loss,
            accurary=accurary,
            prune_op=prune_op,
        )

    def init_sess(self):
        """Create ``self.sess`` and initialise all global and local variables."""
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess = tf.Session()
        self.sess.run(init)

    def train_network(self, graph, x_train, y_train):
        """Run one training step: mask update (``prune_op``) first, then ``optimize``."""
        self.sess.run(graph['prune_op'])
        self.sess.run(graph['optimize'], feed_dict={graph['x']: x_train, graph['y']: y_train})

    def save_model(self):
        """Save all variables of ``self.sess`` to ``save/model.ckpt``; return the path."""
        import os  # only needed to make sure the checkpoint directory exists
        # tf.train.Saver.save fails when the parent directory is missing.
        os.makedirs("save", exist_ok=True)
        # Build the Saver once: each new Saver adds save/restore ops to the graph.
        if getattr(self, "_saver", None) is None:
            self._saver = tf.train.Saver()
        return self._saver.save(self.sess, "save/model.ckpt")

    def load_data(self):
        """Train the pruned network on MNIST, logging progress and checkpointing.

        Every step prints the training-batch cost/accuracy; every 30 steps a
        1000-image test batch is evaluated, the per-mask sparsity is printed
        and a checkpoint is saved when the test accuracy exceeds 0.9.
        """
        mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
        # Build the model first, then initialize it, just once.
        g = self.model_bulid(28, 28, 1, 10)
        # Create the sparsity ops once; creating them inside the loop would keep
        # adding nodes to the graph.
        weight_sparsity = tf.contrib.model_pruning.get_weight_sparsity()
        start = time.time()
        self.init_sess()
        for _ in range(30):
            for i in range(1500):
                batch_xs, batch_ys = mnist.train.next_batch(1000)
                batch_xs = np.reshape(batch_xs, [-1, 28, 28, 1])
                self.train_network(g, batch_xs, batch_ys)
                train_feed = {g['x']: batch_xs, g['y']: batch_ys}
                # One run for both metrics instead of two forward passes.
                cost, acc = self.sess.run([g['cost'], g['accurary']], feed_dict=train_feed)
                print("Train cost accurary print:", "cost: ", cost, "accurary: ", acc)
                if i % 30 == 0:
                    batch_xs_test, batch_ys_test = mnist.test.next_batch(1000)
                    batch_xs_test = np.reshape(batch_xs_test, [-1, 28, 28, 1])
                    test_feed = {g['x']: batch_xs_test, g['y']: batch_ys_test}
                    test_cost, test_acc = self.sess.run([g['cost'], g['accurary']], feed_dict=test_feed)
                    print("******Test cost accurary print******:", "cost: ", test_cost,
                          "accurary: ", test_acc)
                    # Fraction of zeroed mask entries per pruned layer; it grows
                    # toward target_sparsity as pruning proceeds.
                    print("Sparsity of layers", self.sess.run(weight_sparsity))
                    if test_acc > 0.9:
                        self.save_model()

        end = time.time()
        # time.time() is in seconds; convert so the printed label is accurate.
        print((end - start) / 60, "min times")

if __name__ == '__main__':
    # Script entry point: build the pruned LeNet, train it on MNIST and save checkpoints.
    LeNet = LeNet_Mode()
    LeNet.load_data()

五、查看剪枝结果

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


model_dir = "save/"

# Locate the latest checkpoint in model_dir; get_checkpoint_state returns None
# when there is none, which would otherwise fail with an opaque AttributeError.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise SystemExit("No checkpoint found in %r" % model_dir)
ckpt_path = ckpt.model_checkpoint_path

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()

# Print every stored variable (weights and pruning masks alike): name and shape,
# then its values. Sorted so repeated runs list variables in the same order.
for key, val in sorted(param_dict.items()):
    print(key, val)
    try:
        print_tensors_in_checkpoint_file(ckpt_path, tensor_name=key, all_tensors=False,
                                         all_tensor_names=False)
    except Exception as err:
        # Best effort: report the failure and keep going with the other tensors
        # instead of silently swallowing it.
        print("  could not print %s: %s" % (key, err))

可以看到,在 fully_connected1 这一层有一个与权重矩阵形状相同的 0-1 掩模矩阵(mask):0 表示该位置的权值已被剪掉,1 表示该权值被保留。因此,将权重矩阵与该掩模矩阵逐元素相乘,就能得到剪枝后的权重矩阵。

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 3 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 3 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐