Github主页:https://linxid.github.io/
知乎:https://zhuanlan.zhihu.com/p/35775368
CSDN:https://blog.csdn.net/linxid/article/details/79973258
1.计算图
首先解释什么是计算图,了解TensorFlow的计算模型.
和我们常见的程序计算框架不同,并不是赋值,或者计算后,TensorFlow立马完成这些操作,而是将这些操作,赋予在一个图中,这个图可以简单地认为是我们常见的流程图,只不过更加的详细,包括每一步操作(op)和变量的名字.
因为神经网络的运行,需要很大的计算量,如果循环往复的赋值计算,将会带来很大的麻烦,所以TensorFlow采用计算图模型,包括其他很多框架都是采用同样的模型.
比如说,我们建立一个神经网络,那我们首先,定义好变量(参数),然后每一步计算,包括输出.当定以好之后,我们建立一个会话(Session),在这个会话中运行图模型.
1 | # tf.get_default_graph()获取默认的计算图 |
2. TensorFlow中的数据类型
2.1 tf.constant():常量
TensorFlow中的张量模型,也就是我们所说的Tensor的意思.
在TensorFlow中完成的是一步赋值过程,将定义的一个张量保存在变量a或b中.a和b保存的不是张量的值,而是这些数字的运算过程.
注意:
- 1.张量三属性:名字,维度,类型(一般为tf.float32);
- 2.运算时,维度和类型要保持一直;
- 3.名字给出了张量的唯一标识符,同事给出变量是如何计算的;
- 4.张量的计算和普通计算是不同的,保存的不是数值,而是计算过程!!!!
张量的作用:
- 1.对中间计算结果的引用,增加代码的可读性
- 2.计算图构造完成后,用来获得计算结果
1 | import tensorflow as tf |
2.2 tf.Variable()
专门用来,保存和更新神经网络中的参数,所以我们在建立神经网络模型参数的时候,使用tf.Variable()类型.
常用的常数生成函数:
- tf.zeros()
- tf.ones()
- tf.fill()
常用的随机数生成函数:
- tf.random_normal(维度,标准差,类型)
- tf.truncated_normal()
- tf.random_uniform()
1 | w1 = tf.Variable(tf.random_normal([2,3], stddev=1, seed=1)) |
2.3 tf.placeholder()
在前面定义张量时,我们使用的是tf.constant().我们都知道神经网络的训练需要上千上万,甚至几百万的循环迭代,而每次循环迭代,每次生成一个常量,TensorFlow就会在计算图中增加一个节点,最后就会使得,图异常复杂.所以TensorFlow引入tf.placeholder机制.
tf.placeholder相当于提前为变量,定义,开拓了一个位置,这个位置的数据在程序运行时被指定.也就是在计算图中,我们只有一个位置,这个位置的值会不断改变,这样就避免了重复生成.
三个属性:
- 类型
- 维度
- 名字
运行时,使用feed_dict来定义输入.1
2x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
3.会话(Session)
前面我们已经提到了计算图模型,在TensorFlow中,我们使用会话(Session)来执行定义好的计算.会话拥有并管理TensorFlow中的所有资源,运行结束后,需要及时回收资源.
3.1 模式一:
调用会话生成函数,然后关闭会话函数.1
2
3
4
5
6
7
8
9
10sess = tf.Session()
with sess.as_default():
print(result2.eval())
print('\n')
# 这两个有同样的功能
print(sess.run(result2))
print('\n')
print(result2.eval(session=sess))
print('\n')
3.2 模式二:
利用上下文管理器,来使用会话,自动释放资源.1
2
3
4
5
6
7
8
9
10
11
12
with tf.Session(graph=g1) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope('',reuse=True):
print(sess.run(tf.get_variable('v')))
with tf.Session(graph=g2) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope('', reuse=True):
print(sess.run(tf.get_variable('v')))
print('\n')
3.3 模式三:
交互式会话模式,尤其是在Ipython(Jupyter Notebook),这种交互式的环境中,使用非常方便.1
2
3
4# 使用交互式会话模式
sess1 = tf.InteractiveSession()
print(sess1.run(result2))
sess1.close()
4.TensorFlow实现神经网络
1 | # 导入tensorflow库 |