写在前面

张量广播是什么意思?有什么作用?看了这篇文章或许能给你一些启发!


写在中间

一、简单介绍

二、函数讲解

三、代码示例

  1. 首先创建两个简单的二维张量
import tensorflow as tf
tensor1 = tf.ones([4, 3])
tensor2 = tf.ones([4, 3])
print(tensor1)
print(tensor2)
print(tensor1 + tensor2)

输出结果

tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]], shape=(4, 3), dtype=float32)

  1. 接着创建两个维度相同,但形状不同的张量
import tensorflow as tf
tensor1 = tf.ones([3, 4])
tensor2 = tf.ones([4])
print(tensor1)
print(tensor2)

tensor2 = tf.broadcast_to(tensor2, tensor1.shape)
print(tensor1 + tensor2)

输出结果

tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)
tf.Tensor([1. 1. 1. 1.], shape=(4,), dtype=float32)
tf.Tensor(
[[2. 2. 2. 2.]
 [2. 2. 2. 2.]
 [2. 2. 2. 2.]], shape=(3, 4), dtype=float32)


四、错误广播

那么何种情况不能进行广播呢?

  1. 当两个张量在某个维度上的形状既不相等,也不为1时,无法进行广播。例如,一个张量的形状是 (3, 4),另一个张量的形状是 (2, 4),这两个形状在第一个维度上既不相等,也不为1,因此无法进行广播。

  2. 当两个张量的维度数量不匹配时,无法进行广播。例如,一个张量的形状是 (3, 4),另一个张量的形状是 (3, 4, 2),这两个形状的维度数量不同,因此无法进行广播。

写在最后

06-28 07:10