当用户定义了一个继承自 nn.Module 的神经网络模型,并通过调用 model.forward(input) 进行前向传播时,PyTorch 会根据执行的张量操作序列自动构建并维护一个动态计算图,其中的详细过程是:

  1. 初始化输入: 用户首先准备输入数据作为张量,并将其传递给模型。例如:input_data = torch.randn(batch_size, input_features)

  2. 执行前向传播: 当调用model.forward(input_data)时,模型内部开始按顺序执行各个层的操作。这些操作可能包括但不限于卷积、线性变换(全连接层)、激活函数应用(如ReLU、Sigmoid等)以及池化等。

  3. 动态构建计算图在执行每一步涉及张量的操作时,PyTorch都会记录这个操作及其输入和输出的关系,形成一个节点在网络计算图中。例如,当执行conv_layer(input_data)时,会创建一个新的节点代表卷积运算,该节点的输入是input_data,输出是经过卷积后的张量。

  4. 维护依赖关系: 计算图中的边表示了节点间的依赖关系,即一个节点的输出是另一个节点的输入。这种依赖关系在反向传播过程中至关重要,因为它确定了梯度传播的方向。

  5. 实时更新与销毁: 每次前向传播都根据当前输入数据即时构建一个全新的动态计算图,且在完成一次前向传播和相应的反向传播后,系统不会持久保存这次计算图的具体结构,而是会在下次前向传播时重新构建。

  6. 自动求导与参数更新: 当计算出损失并调用.backward()方法时,PyTorch会沿着动态构建的计算图进行反向传播,自动计算所有参与运算的参数相对于目标变量(通常是损失函数)的梯度,并将这些梯度存储在对应参数的 .grad 属性中。然后,优化器使用这些梯度信息来更新模型参数,实现模型训练的目的。

       在PyTorch中,动态计算图构建过程是一个隐式的过程,它随着代码的执行实时地构建和更新。下面详细描述这个过程并结合具体代码示例:

1. 定义模型(nn.Module)

       首先,用户需要定义一个继承自torch.nn.Module的类来创建神经网络模型。在这个类中,通常会重写__init__()方法初始化所有需要的层,并实现forward()方法定义前向传播逻辑。

Python
1import torch
2from torch import nn
3
4class MyModel(nn.Module):
5    def __init__(self):
6        super(MyModel, self).__init__()
7        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
8        self.relu = nn.ReLU()
9        self.fc = nn.Linear(64 * output_size, num_classes)
10
11    def forward(self, x):
12        # 前向传播步骤
13        x = self.conv1(x)  # 第一个节点:卷积操作
14        x = self.relu(x)   # 第二个节点:ReLU激活函数应用
15        x = x.view(-1, 64 * output_size)  # 可能的第三个节点:数据重塑
16        x = self.fc(x)     # 第四个节点:全连接层操作
17        return x

2. 实例化模型与输入张量

       创建模型实例,并生成或加载一个输入张量(即样本数据),同时确保该张量要求计算梯度。

Python
1model = MyModel()
2input_data = torch.randn(batch_size, 3, image_height, image_width).requires_grad_()

3. 执行前向传播(构建动态计算图)

      当调用 model.forward(input_data) 时,PyTorch开始根据张量运算序列自动构建动态计算图。

Python
1output = model(input_data)
  • 在这个过程中,每一步涉及张量的操作都会被记录下来,并形成计算图中的节点。例如,在上面的代码中:
    • conv1(x) 形成了一个代表卷积操作的节点。
    • relu(x) 形成了一个代表ReLU激活函数应用的节点。
    • fc(x) 形成了一个代表全连接层操作的节点。
    • 节点之间的边则表示了数据依赖关系,也就是从一个节点输出到另一个节点输入的张量流。

4. 计算损失

       计算损失函数,这里假设使用交叉熵损失(CrossEntropyLoss)。

Python
1loss_fn = nn.CrossEntropyLoss()
2target_labels = torch.tensor([...])  # 假设是适当的标签数据
3loss = loss_fn(output, target_labels)

5. 反向传播求梯度

调用 .backward() 方法启动反向传播过程。PyTorch将依据动态计算图自底向上回溯,计算每个参数相对于损失函数的梯度。

 

Python

1loss.backward()

在上述过程中,动态计算图构建的关键在于,它不预先固定网络结构,而是随着程序运行时的具体情况进行构建和销毁。这意味着每次迭代可以处理不同的输入大小、改变网络结构,甚至在某些情况下支持条件分支和循环结构,从而提供了极大的灵活性。

01-31 09:50