这篇文章探讨一下如何优雅地保存训练好的TensorFlow模型,使得服务端能够方便地调用模型。

一个简单的方法

首先介绍一个最简单的方法将模型保存为.pb文件。 如果你使用的TensorFlow版本>=1.6的话,可以使用这种方法。 下面给出示例代码。

其中使用到了tf.saved_model.simple_save这个函数,该函数在TensorFlow >= 1.6才有提供。 export_dir变量表示导出模型文件夹路径。模型保存后,export_dir文件夹下的结构如下所示:

1
2
3
4
5
.
+-- saved_model.pb
+-- variables/
|   +-- variables.data-?????-of-?????
|   +-- variables.index

其中,saved_model.pb保存计算图,variables/*保存网络变量。

使用这个模型的方法可见另一篇文章

那么,如果使用的TensorFlow版本<1.6,或者想要更多地控制模型的导出行为,该怎么办呢? 针对第一种情况,其实可以简单地把simple_save这个函数按照tf.saved_model.simple_save的写法自己写一遍来调用; 而对于第二种,则需要了解更多保存细节。

更多的控制

如何将图的变量转为常量保存?