首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >tensorflow运作方式

tensorflow运作方式

作者头像
用户1733462
发布2018-06-07 15:00:16
发布2018-06-07 15:00:16
6210
举报
文章被收录于专栏:数据处理数据处理

定义变量,初始化,一般初始化随机值,或者常值

代码语言:javascript
复制
weights = tf.Variable(tf.random_normal([784, 200],stddev=0.35),
                      name='weights')
from tensorflow.python.framework import ops
ops.reset_default_graph()

biases = tf.Variable(tf.zeros([200]), name='biases')
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    #print sess.run(weights)

保存变量

代码语言:javascript
复制
from tensorflow.python.framework import ops
#ops.reset_default_graph()
g1 = tf.Graph()
print g1
with g1.as_default():
    # 由另一个变量初始化
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), 
                      name='weights')
    w2 =tf.Variable(weights.initialized_value(), name='w2')
    w_twice = tf.Variable(weights.initialized_value()*0.2,name='w_twice')

# 保存变量
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()
with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    print sess.run(weights)
    save_path = saver.save(sess, '/tmp/model.ckpt')
    print 'Model saved in file: ',save_path

恢复变量

代码语言:javascript
复制
#ops.reset_default_graph()
# 恢复变量
g2 = tf.Graph()
with g2.as_default():
    weightss = tf.Variable(tf.zeros([784,200]),name='weights')
    w_2 = tf.Variable(weightss, name='w2')
    w_t = tf.Variable(weightss, name='w_twice')
    print weightss.graph
    saver = tf.train.Saver()
with tf.Session(graph=g2) as sess:
    saver.restore(sess, '/tmp/model.ckpt')
    #print sess.run(weightss)
   # print sess.run(w_2)
    print sess.run(w_t)

保存部分变量

代码语言:javascript
复制
from tensorflow.python.framework import ops
ops.reset_default_graph()
g1 = tf.Graph()
print g1
with g1.as_default():
    # 由另一个变量初始化
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), 
                      name='weights')
    w2 =tf.Variable(weights.initialized_value(), name='w2')
    w_twice = tf.Variable(weights.initialized_value()*0.2,name='w_twice')

# 保存变量
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver({'my_w2':w2,"my_wt":w_twice})
with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    print sess.run(weights)
    save_path = saver.save(sess, '/tmp/model.ckpt')
    print 'Model saved in file: ',save_path

恢复变量

代码语言:javascript
复制
g2 = tf.Graph()
with g2.as_default():
    w_2 = tf.Variable(tf.zeros([784,200]), name='my_w2')
    w_t = tf.Variable(tf.zeros([784,200]), name='my_wt')
    #weightss = tf.Variable(tf.zeros([784,200]),name='my_weight')
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

with tf.Session(graph=g2) as sess:
    sess.run(init_op)
    saver.restore(sess, '/tmp/model.ckpt')

    print sess.run(w_2)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2018.05.15 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 恢复变量
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档