RIPGeo代码理解(五)utils.py( 辅助函数)第一部分-LMLPHP

 代码链接:RIPGeo代码实现

├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数。

一、导入常用库和模块

from __future__ import print_function
import numpy as np
import torch
import warnings
import torch.nn as nn
import random
import matplotlib.pyplot as plt
import copy

这段代码首先包含一些导入语句,接着进行一些版本和警告的处理,最后导入了一些常用的库(numpytorchmatplotlib),并定义了一些常用的模块(nnplt)。

1、from __future__ import print_function:这是为了确保代码同时在Python 2和Python 3中都能正常运行。在Python 2中,print是一个语句,而在Python 3中,print()是一个函数。通过这个导入语句,可以在Python 2中使用Python 3风格的print函数。

2、import numpy as np:导入NumPy库,并用np作为别名。NumPy是一个用于科学计算的库,提供了数组等高性能数学运算工具。

3、import torch::导入PyTorch库。PyTorch是一个深度学习框架,提供了张量计算和神经网络搭建等功能。

4、import warnings:导入warnings模块,用于处理警告。

5、import torch.nn as nn:导入PyTorch中的神经网络模块。

6、import random:导入Python的random模块,用于生成伪随机数。

7、import matplotlib.pyplot as plt:导入matplotlib库的pyplot模块,用于绘制图表。

8、import copy:导入Python的copy模块,用于复制对象。

二、warnings.filterwarnings(action='once')

warnings.filterwarnings(action='once')

设置了在使用warnings.filterwarnings时的参数。filterwarnings函数用于配置警告过滤器,以控制哪些警告会被触发,以及如何处理这些警告。

具体来说,action='once'表示警告信息只会被显示一次。这对于一些可能会频繁触发的警告而言是一种控制方式,以避免在控制台或日志中大量重复的警告信息。在第一次触发警告时,它会被显示,但在后续的同类警告中,将不再显示。

请注意,这个配置仅适用于在warnings模块中配置的警告,它并不会影响其他类型的警告或错误。

三、DataPerturb()  数据扰动

RIPGeo代码理解(五)utils.py( 辅助函数)第一部分-LMLPHP

class DataPerturb:
    def __init__(self, eta=1):
        self.eta = eta
        self.loss = torch.nn.MSELoss(reduction='sum')

    def perturb(self, model, data):
        # original
        lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data

        # obtain new graph representation
        _, ori_graph_feature = model(lm_X, lm_Y, tg_X,
                                     tg_Y, lm_delay,
                                     tg_delay)

        # add Gaussian data perturb
        new_lm_X, new_lm_Y, new_tg_X, new_tg_Y, new_lm_delay, new_tg_delay = lm_X.clone(), lm_Y.clone(), \
                                                                             tg_X.clone(), tg_Y.clone(), \
                                                                             lm_delay.clone(), tg_delay.clone()
        new_lm_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_lm_X[:, -16:]) * new_lm_X[:, -16:]).cuda()
        new_tg_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_tg_X[:, -16:]) * new_tg_X[:, -16:]).cuda()
        new_lm_delay += self.eta * torch.normal(0, torch.ones_like(new_lm_delay) * new_lm_delay).cuda()
        new_tg_delay += self.eta * torch.normal(0, torch.ones_like(new_tg_delay) * new_tg_delay).cuda()

        # obtain new graph representation
        _, new_graph_feature = model(new_lm_X, new_lm_Y, new_tg_X,
                                     new_tg_Y, new_lm_delay,
                                     new_tg_delay)

        data_loss = self.loss(ori_graph_feature, new_graph_feature)
        return data_loss

这段代码定义了一个名为 DataPerturb 的类,其目的是对给定的数据进行扰动,并计算扰动后的损失。

(一)__init__()

    def __init__(self, eta=1):
        self.eta = eta
        self.loss = torch.nn.MSELoss(reduction='sum')

__init__ 方法中,类初始化时可以指定一个参数 eta,默认为1。该参数用于控制扰动的强度。

损失函数使用MSELoss。

(二)perturb()

    def perturb(self, model, data):
        # original
        lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data

        # obtain new graph representation
        _, ori_graph_feature = model(lm_X, lm_Y, tg_X,
                                     tg_Y, lm_delay,
           
03-25 20:31