tensorflow2.0 --- api之model.fit
https://tensorflow.google.cn/api_docs/python/tf/keras/Model?version=stable#fit
作用
使用数据训练模型
定义
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
**kwargs)
输出
一个History对像, History.history表示训练的历史数据,里面包含了loss, metrics, validation loss, validation metrics。
参数
x
输入数据,可以是以下形式
- Numpy数组,或者Numpy数组列表
- 一个Tensorflow张量,或者张量列表
- 指向数组/张量的字典key
- 一个tf.data数据集,返回(inputs,targets)或者(inputs,targets,sample_weights)
- 一个generator或者 keras.utils.Sequence, 返回(inputs,targets)或者(inputs,targets,sample_weights)
y
目标数据(标签数据),类似于输入数据x, 它要么是一个Numpy数组,要么是一个张量。
y与x要保持一致,当x是Numpy数组时,y也要是Numpy数组,当x是张量时,y也要是张量。
当x是数据集、generator、keras.utils.Sequence时,y要为空,因为y会通过x生成。
batch_size
每次更新梯度时,训练的样本数量,默认值为32
当x是符号张量、数据集、generator、keras.utils.Sequence时,不要设置这个参数
epochs
一个epoch表示在整个x和y数据上的所有训练迭代,epochs表示训练最终到达的epoch索引值,并不表示训练次数。
initial_epoch
起始训练epoch索引值,用于唤醒之前的一个训练,当前训练次数=epochs-initial_epoch
verbose
详细模式, 0=静默, 1=进度条, 2=一个epoch一行。
注意记录到文件时,进度条就不是那么有用,建议使用verbose=2.
callbacks
tf.keras.callbacks.* 实例列表,训练的时候会应用列表中的回调方法
validation_split
训练数据与验证数据的比例,取值范围0~1。当x为dataset, generator 或者 keras.utils.Sequence时,不支持这个参数。
validation_data
验证数据,这个参数会覆盖validation_split,它可以是
- 元组(x_val,y_val), x和y是Numpy数组或者张量,必须提供batch_size
- 元组(x_val,y_val,val_sample_weights), x和y是Numpy数组或者张量,必须提供batch_size
- dataset, 必须提供validation_steps
shuffle
当取值是Boolean, 表示是否每个epoch对训练数据洗牌,默认为True
当取值是batch,针对HDF5数据的限制,每个batch洗牌一次。
当steps_per_epoch参数不为空时,shuffle参数不起作用。
class_weight
sample_weight
steps_per_epoch
每个epoch中的步数,(batches of samples)
- x是张量,steps_per_epoch=None时,步数=samples/batch_size。
- x是dataset,steps_per_epoch=None时, epoch会运行直到dataset用尽。
这个参数不支持输入为数组的情况。
validation_steps
仅validation_data不为空时有效,同steps_per_epoch差不多
validation_freq
仅validation_data不为空时有效,取值可以是整数或者集合
- 取值为整数时,表示运行多少次epoch后,执行一次验证
- 取值为集合时,表示在哪些epoch运行后,执行一次验证
max_queue_size
int类型,最大队列大小,仅用于输入为generator 或者 keras.utils.Sequence时。
workers
int类型,工作线程数量,仅用于输入为generator 或者 keras.utils.Sequence时。
use_multiprocessing
Boolean类型,是否使用多线程,仅用于输入为generator 或者 keras.utils.Sequence时。
更多推荐
所有评论(0)