博主在根据官网配置图像分类迁移学习时,由于没有设置翻墙,程序执行如下语句时

model = image_classifier.create(train_data)

会因为模型下载超时而报错:urllib.error.URLError: <urlopen error [Errno 110] Connection timed out>

博主debug看了下,在/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/_init_.py文件里(结合自己的路径)看到了预先设定好的模型配置

切换到/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/image_spec.py文件,可以看到每个模型的下载路径

 在不指定模型路径情况相下,系统默认使用的是efficientnet_lite0模型,对应路径是

https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2

博主想直接到tensorflow hub网站上去下载

TensorFlow Hubhttps://tensorflow.google.cn/hub/但点击'查看模型'没有反应

可看到图中示例可以直接通过tensorflow_hub.KerasLayer函数通过路径来加载模型,其实再深入debug,你会发现上面的mage_classifier.create()里面其实也调用了KerasLayer函数

 

为了网页浏览模型,可以访问如下网址:

TensorFlow Hubhttps://hub.tensorflow.google.cn如下页面中可以根据条件去筛选

模型下载不了解决方法

输入上面的网址https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2,发现没有反应,可以对链接进行如下的更改即可。

(1)https://tfhub.dev修改为https://storage.googleapis.com/tfhub-modules

(2) 2修改为2.tar.gz

修改后的访问网址应该为:https://storage.googleapis.com/tfhub-modules/tensorflow/efficientnet/lite0/feature-vector/2.tar.gz

代码修改为如下:

import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
import tensorflow_hub as hub

image_path = tf.keras.utils.get_file(
      'flower_photos.tgz',
      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
print(train_data.size)
print(test_data.size)

#inception_v3_spec = image_classifier.ModelSpec(uri='/home/sxhlvye/efficientnet_lite0_feature-vector_2')
inception_v3_spec = image_classifier.ModelSpec(uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientnet/lite0/feature-vector/2.tar.gz')
inception_v3_spec.input_image_shape = [240, 240]
model = image_classifier.create(train_data, model_spec=inception_v3_spec)

print("ok")



运行部分结果如下:

=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
Epoch 1/5
2022-04-17 20:08:02.362645: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8100
103/103 [==============================] - 9s 49ms/step - loss: 0.8722 - accuracy: 0.7630
Epoch 2/5
103/103 [==============================] - 5s 50ms/step - loss: 0.6609 - accuracy: 0.8941
Epoch 3/5
103/103 [==============================] - 5s 50ms/step - loss: 0.6217 - accuracy: 0.9181
Epoch 4/5
103/103 [==============================] - 5s 51ms/step - loss: 0.6085 - accuracy: 0.9190
Epoch 5/5
103/103 [==============================] - 5s 52ms/step - loss: 0.5915 - accuracy: 0.9354
ok

上面图片会默认下载到如下路径(结合自己的博客)

 训练自己数据集用于分类的时候,就可以借鉴此目录结构。

补充:

从 TF Hub 缓存下载的模型  |  TensorFlow Hub

可以看到下载的模型保存的位置

博主把上面文件夹中的内容拷贝别的一个位置

代码中路径设定为本地地址(结合自己的路径),程序可以正常运行

inception_v3_spec = image_classifier.ModelSpec(uri='/home/sxhlvye/efficientnet_lite0_feature-vector_2')

GitHub 加速计划 / te / tensorflow
184.55 K
74.12 K
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
a49e66f2 PiperOrigin-RevId: 663726708 2 个月前
91dac11a This test overrides disabled_backends, dropping the default value in the process. PiperOrigin-RevId: 663711155 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐