tensorflow estimator详细介绍,实现模型的高效训练
tensorflow
一个面向所有人的开源机器学习框架
项目地址:https://gitcode.com/gh_mirrors/te/tensorflow
免费下载资源
·
estimator
是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier
,但这不是这篇博客的主题,而是怎么使用estimator
来实现我们自定义模型的训练。它的步骤主要分为以下几个部分:
- 构建
model_fn
,在这个方法里面定义自己的模型以及训练和测试过程要做的事情; - 构建
input_fn
,在这个方法数据的来源和喂给模型的方式; - 最后,创建
estimator
对象,然后开始训练模型了。可以添加一些config,比如:loss的输出频率等。
构建model_fn方法
import tensorflow as tf
def model_fn(features, labels, mode, params): # 必须要有前面三个参数
# feature和labels其实就是`input_fn`方法传输过来的
# mode是用来判断你现在是训练或测试阶段
# params是在创建`estimator`对象的输入参数
lr = params['lr']
try:
init_checkpoint = params['init_checkpoint']
except KeyError:
init_checkpoint = None
x = features['inputs']
y = features['labels']
#####################在这里定义你自己的网络模型###################
pre = tf.layers.dense(x, 1)
loss = tf.reduce_mean(tf.pow(pre-y, 2), name='loss')
######################在这里定义你自己的网络模型###################
# 这里可以加载你的预训练模型
assignment_map = dict()
if init_checkpoint:
for var in tf.train.list_variables(init_checkpoint): # 存放checkpoint的变量名称和shape
assignment_map[var[0]] = var[0]
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
# 定义你训练过程要做的事情
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# 定义你测试(验证)过程
elif mode == tf.estimator.ModeKeys.EVAL:
metrics = {'eval_loss': tf.metrics.mean_tensor(loss), "accuracy": tf.metrics.accuracy(labels, pre)}
output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
# 定义你的预测过程
elif mode == tf.estimator.ModeKeys.PREDICT:
predictions = {'predictions': pre}
output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)
else:
raise TypeError
return output_spec
提几点需要注意的地方:
model_fn
方法返回的是tf.estimator.EstimatorSpec
;- TRAIN、EVAL和PREDICT模式不可缺少的参数是不一样的。
构建input_fn方法
def input_fn_bulider(inputs_file, batch_size, is_training):
name_to_features = {'inputs': tf.FixedLenFeature([3], tf.float32),
'labels': tf.FixedLenFeature([], tf.float32)}
def input_fn(params):
d = tf.data.TFRecordDataset(inputs_file)
if is_training:
d = d.repeat()
d = d.shuffle()
# map_and_batch其实就是将map和batch结合起来而已
d = d.apply(tf.contrib.data.map_and_batch(lambda x: tf.parse_single_example(x, name_to_features),
batch_size=batch_size))
return d
return input_fn
执行eatimator
if __name__ == '__main':
# 定义日志消息的输出级别,为了获取模型的反馈信息,选择INFO
tf.logging.set_verbosity(tf.logging.INFO)
# 我在这里是指定模型的保存和loss输出频率
runConfig = tf.estimator.RunConfig(save_checkpoints_steps=1,
log_step_count_steps=1)
estimator = tf.estimator.Estimator(model_fn, model_dir='your_save_path',
config=runConfig, params={'lr': 0.01})
# log_step_count_steps控制的只是loss的global_step的输出
# 我们还可以通过tf.train.LoggingTensorHook自定义更多的输出
# tensor是我们要输出的内容,输入一个字典,key为打印出来的名称,value为你要输出的tensor的name
logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
tensors={'loss': 'loss'})
# 其实给到estimator.train是一个dataset对象
input_fn = input_fn_bulider('test.tfrecord', batch_size=1, is_training=True)
estimator.train(input_fn, max_steps=1000)
# 下面你还可以对模型进行验证和测试,做法是差不多的,我就不列举了
欢迎关注同名公众号:“我就算饿死也不做程序员”。
交个朋友,一起交流,一起学习,一起进步。
GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2
PiperOrigin-RevId: 663726708
3 个月前
91dac11a
This test overrides disabled_backends, dropping the default
value in the process.
PiperOrigin-RevId: 663711155
3 个月前
更多推荐
已为社区贡献7条内容
所有评论(0)