class saver
作用:用于保存和恢复网络的参数
Function
Init
1 |
|
Args:
var_list: 指定保存的变量。可以使list或者dict
max_to_keep: 最近可以保存ckeckpoints的最大数目
keep_checkpoint_every_n_hours: 多久保存一次ckeckpoints1
2
3
4
5
6
7
8
9
10
11
12
13
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
save
1 | save( |
作用:保存网络参数
Args:
sess: 会话句柄
save_path:保存路径
global_step:步数
restore
作用:恢复网络参数
1 | restore( |
Args:
sess: 会话句柄
save_path:保存路径
Example
1 | # Create a saver. |
查看保存文件里的tensor
1 | from tensorflow.python import pywrap_tensorflow |
保存文件:
- .meta文件保存了当前图结构
- .index文件保存了当前参数名
- .data文件保存了当前参数值
注:每调用一次save方法会产生新的文件
To do
- transfer learning 里如何恢复部分模型参数