简单高效的BERT中文文本分类模型开发和部署

准备环境

  • 操作系统:Linux(Mac/windows)
  • TensorFlow Version:1.13.1,动态图模式
  • GPU:12G GPU(CPU跑得很慢)
  • TensorFlow Serving:simple-tensorflow-serving
  • 依赖库:requirements.txt

注意:在windows下运行sh文件,需要安装git工具,在…/Git/bin文件夹中,运行sh.exe后出现窗口,cd到需要运行的文件目录中,输入sh train.sh运行。

目录结构说明

在这里插入图片描述

  • bert是官方源码
  • data是3分类的文本情感分析数据(可直接将.data改为.tsv)
  • train.sh、classifier.py 训练文件(bert中文训练时以字做切分)
  • export.sh、export.py 导出TF serving的模型
  • client.sh、client.py、file_base_client.py 处理输入数据并向部署的TF-serving的模型发出请求,打印输出结果

训练代码

在bert_classifier.py中写一个自定义的Myprocessor类,继承了run_classifier.py中的DataProcessor。
写一个自己的文本处理器,需要注意:

  1. 改写label
  2. 把create_examples改成了共有方法,因为我们后面要调用。
  3. file_base的时候注意跳过第一行,文件数据的第一行是title。
class MyProcessor(DataProcessor):

    def get_test_examples(self, data_dir):
        return self.create_examples(
            self._read_tsv(os.path.join(data_dir, "test.data")), "test")

    def get_train_examples(self, data_dir):
        """See base class."""
        return self.create_examples(
            self._read_tsv(os.path.join(data_dir, "train.data")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self.create_examples(
            self._read_tsv(os.path.join(data_dir, "val.data")), "val")

    def get_pred_examples(self, data_dir):
        return self.create_examples(
            self._read_tsv(os.path.join(data_dir, "pred.data")), "pred")

    def get_labels(self):
        """See base class."""
        return ["-1", "0", "1"]

    def create_examples(self, lines, set_type, file_base=True):
        """Creates examples for the training and dev sets. each line is label+\t+text_a+\t+text_b """
        examples = []
        for (i, line) in tqdm(enumerate(lines)):

            if file_base:
                if i == 0:
                    continue

            guid = "%s-%s" % (set_type, i)
            text = tokenization.convert_to_unicode(line[1])
            if set_type == "test" or set_type == "pred":
                label = "0"
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text, label=label))   # 对于分类任务,单输入单输出,只需要text_a,不需要text_b
        return examples

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "setiment": MyProcessor
  }
  ...
  1. 其他的训练代码,抄官方的就行
  2. 可以直接运行train.sh,注意修改对应的路径
  3. 生成的ckpt文件在output路径下

导出模型

主要代码如下,生成的pb文件在api文件夹下

def serving_input_receiver_fn():
    input_ids = tf.placeholder(dtype=tf.int64, shape=[None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(dtype=tf.int64, shape=[None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(dtype=tf.int64, shape=[None, FLAGS.max_seq_length], name='segment_ids')
    label_ids = tf.placeholder(dtype=tf.int64, shape=[None, ], name='unique_ids')

    receive_tensors = {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids,
                       'label_ids': label_ids}
    features = {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids, "label_ids": label_ids}
    return tf.estimator.export.ServingInputReceiver(features, receive_tensors)

estimator.export_savedmodel(FLAGS.serving_model_save_path, serving_input_receiver_fn)

TensorFlow Serving部署

一键部署:

simple_tensorflow_serving --model_base_path="./api"

本地请求代码

分为两种,一种是读取文件的,就是要预测的文本是tsv文件的,叫做file_base_client.py,另一个直接输入文本的是client.py。首先更改input_fn_builder,返回dataset,然后从dataset中取数据,转换为list格式,传入模型,返回结果。

GitHub 加速计划 / be / bert
37.61 K
9.55 K
下载
TensorFlow code and pre-trained models for BERT
最近提交(Master分支:2 个月前 )
eedf5716 Add links to 24 smaller BERT models. 4 年前
8028c045 - 4 年前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐