https://tensorflow.google.cn/api_docs/python/tf/keras/Model?version=stable#fit

作用

使用数据训练模型

定义

  def fit(self,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False,
          **kwargs)

输出

一个History对像, History.history表示训练的历史数据,里面包含了loss, metrics, validation loss, validation metrics。

参数

x

输入数据,可以是以下形式

  • Numpy数组,或者Numpy数组列表
  • 一个Tensorflow张量,或者张量列表
  • 指向数组/张量的字典key
  • 一个tf.data数据集,返回(inputs,targets)或者(inputs,targets,sample_weights)
  • 一个generator或者 keras.utils.Sequence, 返回(inputs,targets)或者(inputs,targets,sample_weights)

y

目标数据(标签数据),类似于输入数据x, 它要么是一个Numpy数组,要么是一个张量。
y与x要保持一致,当x是Numpy数组时,y也要是Numpy数组,当x是张量时,y也要是张量。
当x是数据集、generator、keras.utils.Sequence时,y要为空,因为y会通过x生成。

batch_size

每次更新梯度时,训练的样本数量,默认值为32
当x是符号张量、数据集、generator、keras.utils.Sequence时,不要设置这个参数

epochs

一个epoch表示在整个x和y数据上的所有训练迭代,epochs表示训练最终到达的epoch索引值,并不表示训练次数。

initial_epoch

起始训练epoch索引值,用于唤醒之前的一个训练,当前训练次数=epochs-initial_epoch

verbose

详细模式, 0=静默, 1=进度条, 2=一个epoch一行。
注意记录到文件时,进度条就不是那么有用,建议使用verbose=2.

callbacks

tf.keras.callbacks.* 实例列表,训练的时候会应用列表中的回调方法

validation_split

训练数据与验证数据的比例,取值范围0~1。当x为dataset, generator 或者 keras.utils.Sequence时,不支持这个参数。

validation_data

验证数据,这个参数会覆盖validation_split,它可以是

  • 元组(x_val,y_val), x和y是Numpy数组或者张量,必须提供batch_size
  • 元组(x_val,y_val,val_sample_weights), x和y是Numpy数组或者张量,必须提供batch_size
  • dataset, 必须提供validation_steps

shuffle

当取值是Boolean, 表示是否每个epoch对训练数据洗牌,默认为True
当取值是batch,针对HDF5数据的限制,每个batch洗牌一次。
当steps_per_epoch参数不为空时,shuffle参数不起作用。

class_weight

sample_weight

steps_per_epoch

每个epoch中的步数,(batches of samples)

  • x是张量,steps_per_epoch=None时,步数=samples/batch_size。
  • x是dataset,steps_per_epoch=None时, epoch会运行直到dataset用尽。
    这个参数不支持输入为数组的情况。

validation_steps

仅validation_data不为空时有效,同steps_per_epoch差不多

validation_freq

仅validation_data不为空时有效,取值可以是整数或者集合

  • 取值为整数时,表示运行多少次epoch后,执行一次验证
  • 取值为集合时,表示在哪些epoch运行后,执行一次验证

max_queue_size

int类型,最大队列大小,仅用于输入为generator 或者 keras.utils.Sequence时。

workers

int类型,工作线程数量,仅用于输入为generator 或者 keras.utils.Sequence时。

use_multiprocessing

Boolean类型,是否使用多线程,仅用于输入为generator 或者 keras.utils.Sequence时。

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐