【Kuiperinfer】笔记04 计算图的定义-LMLPHP

本节总览

这一小节需要了解Kuiperinfer项目使用的模型格式PNNX,以及其使用的计算图相关的结构定义。

然后对计算图进行封装。

PNNX计算图

KuiperInfer使用的模型格式为PNNX,是Pytorch Neural Network Exchange的缩写。这个格式是NCNN框架的一个pr中引入的,该格式希望能够将Pytorch模型导出为高效、简洁的计算图。详细可参考https://zhuanlan.zhihu.com/p/427620428

回顾一下绪论中的内容,概念上的计算图主要包含以下部分:

  1. Operator:计算图中的计算节点
  2. Graph:多个Operator串联得到的有向无环图,规定了计算节点的执行顺序
  3. Layer:Operator中具体执行运算的部分,首先读入输入Tensor中的数据,然后计算,将结果存放在输出Tensor中
  4. Tensor:存放多维数据

而PNNX计算图包括Graph, OperatorOperand三种结构。下面将逐个分析每个结构的类定义。

测试模型

Kuiperinfer提供了一对小模型文件linear.paramlinear.bin,分别为网络的结构和权重文件。网络的pytorch定义如下:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(32, 128)
    def forward(self, x):
        x = self.linear(x)
        x = F.sigmoid(x)
        return x

可以使用Netron打开linear.param文件观察网络的结构

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

Linear层的结构,包含两个操作数#0, #1和两个权重属性@weight, @bias

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

Graph

Graph的功能是管理计算图中的OperatorOperand;其结构定义如下

class Graph
{
    // 算子
    Operator* new_operator(const std::string& type, const std::string& name);
    Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
    Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
    
    // 操作数
    Operand* new_operand(const torch::jit::Value* v);
    Operand* new_operand(const std::string& name);
    Operand* get_operand(const std::string& name);
    
    std::vector<Operator*> ops;
    std::vector<Operand*> operands;
    
    // 加载/保存
    int load(const std::string& parampath, const std::string& binpath);
    int save(const std::string& parampath, const std::string& binpath);
};

单元测试 - 输出算子

TEST(test_ir, pnnx_graph_ops) {
  using namespace kuiper_infer;
  /**
   * 如果这里加载失败,请首先考虑相对路径的正确性问题
   */
  // LOG(INFO) << std::filesystem::current_path(); // 需要# include<filesystem>, 加这一行检查路径
  std::string bin_path("../../../src/model_file/test_linear.pnnx.bin"); // 这里要填远程主机上的模型相对路径
  std::string param_path("../../../src/model_file/test_linear.pnnx.param");
  std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
  int load_result = graph->load(param_path, bin_path);
  // 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
  ASSERT_EQ(load_result, 0);
  const auto &ops = graph->ops;
  for (int i = 0; i < ops.size(); ++i) {
    LOG(INFO) << ops.at(i)->name;
  }
}
I20240301 08:55:21.320151  1057 test_ir.cpp:37] pnnx_input_0
I20240301 08:55:21.320165  1057 test_ir.cpp:37] linear
I20240301 08:55:21.320171  1057 test_ir.cpp:37] F.sigmoid_0
I20240301 08:55:21.320178  1057 test_ir.cpp:37] pnnx_output_0

可以看到,Graph中存在四个算子:input_0, linear, sigmoid_0, output_0

单元测试 - 输出操作数

TEST(test_ir, pnnx_graph_operands) {
  using namespace kuiper_infer;
  /**
   * 如果这里加载失败,请首先考虑相对路径的正确性问题
   */
  std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
  std::string param_path("../../../src/model_file/test_linear.pnnx.param");
  std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
  int load_result = graph->load(param_path, bin_path);
  // 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
  ASSERT_EQ(load_result, 0);
  const auto &ops = graph->ops;
  for (int i = 0; i < ops.size(); ++i) {
    const auto &op = ops.at(i); // 取算子
    LOG(INFO) << "OP Name: " << op->name;
    LOG(INFO) << "OP Inputs";
    for (int j = 0; j < op->inputs.size(); ++j) {
      LOG(INFO) << "Input name: " << op->inputs.at(j)->name
                << " shape: " << ShapeStr(op->inputs.at(j)->shape);
    }

    LOG(INFO) << "OP Output";
    for (int j = 0; j < op->outputs.size(); ++j) {
      LOG(INFO) << "Output name: " << op->outputs.at(j)->name
                << " shape: " << ShapeStr(op->outputs.at(j)->shape);
    }
    LOG(INFO) << "---------------------------------------------";
  }
}
I20240301 08:55:21.320255  1057 test_ir.cpp:56] OP Name: pnnx_input_0
I20240301 08:55:21.320262  1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320267  1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320274  1057 test_ir.cpp:65] Output name: 0 shape: 1 x 32
I20240301 08:55:21.320281  1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320287  1057 test_ir.cpp:56] OP Name: linear
I20240301 08:55:21.320292  1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320298  1057 test_ir.cpp:59] Input name: 0 shape: 1 x 32
I20240301 08:55:21.320305  1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320362  1057 test_ir.cpp:65] Output name: 1 shape: 1 x 128
I20240301 08:55:21.320374  1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320380  1057 test_ir.cpp:56] OP Name: F.sigmoid_0
I20240301 08:55:21.320386  1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320392  1057 test_ir.cpp:59] Input name: 1 shape: 1 x 128
I20240301 08:55:21.320405  1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320420  1057 test_ir.cpp:65] Output name: 2 shape: 1 x 128
I20240301 08:55:21.320432  1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320443  1057 test_ir.cpp:56] OP Name: pnnx_output_0
I20240301 08:55:21.320453  1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320461  1057 test_ir.cpp:59] Input name: 2 shape: 1 x 128
I20240301 08:55:21.320467  1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320478  1057 test_ir.cpp:68] ---------------------------------------------

Operator

Operator的结构定义如下:

class Operator
{
public:
    // 输入输出操作数
    std::vector<Operand*> inputs; 
    std::vector<Operand*> outputs;
    
    // Operator的类型和名称
    std::string type;
    std::string name;
    
    std::vector<std::string> inputnames;
    std::map<std::string, Parameter> params; 
    std::map<std::string, Attribute> attrs; 
}

单元测试

尝试输出Linear层的信息,并与可视化结果进行对照

TEST(test_ir, pnnx_graph_operands_and_params) {
  using namespace kuiper_infer;
  /**
   * 如果这里加载失败,请首先考虑相对路径的正确性问题
   */
  std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
  std::string param_path("../../../src/model_file/test_linear.pnnx.param");
  std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
  int load_result = graph->load(param_path, bin_path);
  // 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
  ASSERT_EQ(load_result, 0);
  const auto &ops = graph->ops;
  for (int i = 0; i < ops.size(); ++i) {
    const auto &op = ops.at(i);
    if (op->name != "linear") { // 只打印Linear
      continue;
    }
    LOG(INFO) << "OP Name: " << op->name;
    LOG(INFO) << "OP Inputs";
    for (int j = 0; j < op->inputs.size(); ++j) {
      LOG(INFO) << "Input name: " << op->inputs.at(j)->name
                << " shape: " << ShapeStr(op->inputs.at(j)->shape);
    }

    LOG(INFO) << "OP Output";
    for (int j = 0; j < op->outputs.size(); ++j) {
      LOG(INFO) << "Output name: " << op->outputs.at(j)->name
                << " shape: " << ShapeStr(op->outputs.at(j)->shape);
    }

    LOG(INFO) << "Params";
    for (const auto &attr : op->params) {
      LOG(INFO) << attr.first << " type " << attr.second.type;
    }

    LOG(INFO) << "Weight: ";
    for (const auto &weight : op->attrs) {
      LOG(INFO) << weight.first << " : " << ShapeStr(weight.second.shape)
                << " type " << weight.second.type;
    }
    LOG(INFO) << "---------------------------------------------";
  }
}
I20240302 12:06:46.516649  2026 test_ir.cpp:90] OP Name: linear
I20240302 12:06:46.516660  2026 test_ir.cpp:91] OP Inputs
I20240302 12:06:46.516680  2026 test_ir.cpp:93] Input name: 0 shape: 1 x 32
I20240302 12:06:46.516691  2026 test_ir.cpp:97] OP Output
I20240302 12:06:46.516698  2026 test_ir.cpp:99] Output name: 1 shape: 1 x 128
I20240302 12:06:46.516705  2026 test_ir.cpp:103] Params
I20240302 12:06:46.516714  2026 test_ir.cpp:105] bias type 1
I20240302 12:06:46.516724  2026 test_ir.cpp:105] in_features type 2
I20240302 12:06:46.516733  2026 test_ir.cpp:105] out_features type 2
I20240302 12:06:46.516741  2026 test_ir.cpp:108] Weight:
I20240302 12:06:46.516749  2026 test_ir.cpp:110] bias : 128 type 1
I20240302 12:06:46.516758  2026 test_ir.cpp:110] weight : 128 x 32 type 1
I20240302 12:06:46.516765  2026 test_ir.cpp:113] ---------------------------------------------

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

Parameter 和 Attribute

权重数据结构和参数数据结构的定义如下,可参考Operator中单元测试的结果进行对照;

class Parameter
{
public:
    Parameter()
        : type(0)
    {
    }
    ...
#if BUILD_PNNX
    Parameter(const torch::jit::Node* value_node);
    Parameter(const torch::jit::Value* value);
#endif // BUILD_PNNX

    static Parameter parse_from_string(const std::string& value);

    // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
    int type;

    // value
    bool b;
    int i;
    float f;
    std::vector<int> ai;
    std::vector<float> af;

    // keep std::string typed member the last for cross cxxabi compatibility
    std::string s;
    std::vector<std::string> as;
};
class Attribute
{
public:
    Attribute()
        : type(0)
    {
    }

#if BUILD_PNNX
    Attribute(const at::Tensor& t);
#endif // BUILD_PNNX

    Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool
    int type;
    std::vector<int> shape;

    std::vector<char> data;
};

Operand

Operand的定义如下:

class Operand
{
public:
    void remove_consumer(const Operator* c);

    Operator* producer; // 产生该操作数的算子,只能有一个
    std::vector<Operator*> consumers; // 使用该操作数的算子,可以有多个

    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=cp64 11=cp128 12=cp32
    int type;
    std::vector<int> shape;

    // keep std::string typed member the last for cross cxxabi compatibility
    std::string name;

    std::map<std::string, Parameter> params;

};

单元测试

TEST(test_ir, pnnx_graph_operands_customer_producer) {
  using namespace kuiper_infer;
  /**
   * 如果这里加载失败,请首先考虑相对路径的正确性问题
   */
  std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
  std::string param_path("../../../src/model_file/test_linear.pnnx.param");
  std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
  int load_result = graph->load(param_path, bin_path);
  // 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
  ASSERT_EQ(load_result, 0);
  const auto &operands = graph->operands;
  for (int i = 0; i < operands.size(); ++i) {
    const auto &operand = operands.at(i);
    LOG(INFO) << "Operand Name: #" << operand->name;
    LOG(INFO) << "Customers: ";
    for (const auto &customer : operand->consumers) {
      LOG(INFO) << customer->name;
    }

    LOG(INFO) << "Producer: " << operand->producer->name;
  }
}
I20240302 12:06:46.516847  2026 test_ir.cpp:131] Operand Name: #0
I20240302 12:06:46.516860  2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516870  2026 test_ir.cpp:134] linear
I20240302 12:06:46.516880  2026 test_ir.cpp:137] Producer: pnnx_input_0
    
I20240302 12:06:46.516891  2026 test_ir.cpp:131] Operand Name: #1
I20240302 12:06:46.516901  2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516911  2026 test_ir.cpp:134] F.sigmoid_0
I20240302 12:06:46.516922  2026 test_ir.cpp:137] Producer: linear
    
I20240302 12:06:46.516932  2026 test_ir.cpp:131] Operand Name: #2
I20240302 12:06:46.516942  2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516953  2026 test_ir.cpp:134] pnnx_output_0
I20240302 12:06:46.516964  2026 test_ir.cpp:137] Producer: F.sigmoid_0

可以结合模型的可视化进行观察:操作数#0pnnx_input_0产生,在linear层使用;操作数#1linear产生,在F.sigmoid_0层使用,操作数#2同理。

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

Kuiperinfer封装

首先给出UML结构图

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

可知Kuiperinfer的计算图核心结构是RuntimeOperator,该结构是对PNNX::Operator的封装。

定义如下:

struct RuntimeOperator {
  virtual ~RuntimeOperator();

  bool has_forward = false;
  std::string name;      /// 计算节点的名称
  std::string type;      /// 计算节点的类型
  std::shared_ptr<Layer> layer;  /// 节点对应的计算Layer

  std::vector<std::string> output_names;  /// 节点的输出节点名称
  std::shared_ptr<RuntimeOperand> output_operands;  /// 节点的输出操作数

  std::map<std::string, std::shared_ptr<RuntimeOperand>>
      input_operands;  /// 节点的输入操作数
  std::vector<std::shared_ptr<RuntimeOperand>>
      input_operands_seq;  /// 节点的输入操作数,顺序排列
  std::map<std::string, std::shared_ptr<RuntimeOperator>>
      output_operators;  /// 输出节点的名字和节点对应

  std::map<std::string, RuntimeParameter*> params;  /// 算子的参数信息
  std::map<std::string, std::shared_ptr<RuntimeAttribute>>
      attribute;  /// 算子的属性信息,内含权重信息
};
  • name:节点的名称,用来区分唯一节点,例如Conv_1, Conv_2

  • type:节点类型,例如Convolution, Relu

  • layer:完成具体计算的组件,例如在卷积算子中的卷积计算

  • input_operands, output_operands:运算符的输入输出操作数,若输入为4, 3, 224, 224,则input_operandsdatas数组长度为4,数组中每个元素张量的大小为3, 224, 224

    struct RuntimeOperand {
      // 假设输入为BCHW = (4, 3, 224, 224)
      std::string name;                                     /// 操作数的名称
      std::vector<int32_t> shapes;                          /// 操作数的形状
      std::vector<std::shared_ptr<Tensor<float>>> datas;    /// 存储操作数,长度为4,每个Tensor形状为(3, 224, 224)
      RuntimeDataType type = RuntimeDataType::kTypeUnknown; /// 操作数的类型,一般是float
    };
    
  • params:算子的参数信息,例如卷积层的核大小、步长

  • attribute:算子的weights,bias

从PNNX::Operator到Kuiper::RuntimeOperator

PNNX::Operator中提取信息,填入Kuiperinfer对应的数据结构。

实现如下,首先给出计算图的类定义:

class RuntimeGraph {
public:
  /**
   * 初始化计算图
   * @param param_path 计算图的结构文件
   * @param bin_path 计算图中的权重文件
   */
  RuntimeGraph(std::string param_path, std::string bin_path);

  /**
   * 设置权重文件
   * @param bin_path 权重文件路径
   */
  void set_bin_path(const std::string &bin_path);

  /**
   * 设置结构文件
   * @param param_path  结构文件路径
   */
  void set_param_path(const std::string &param_path);

  /**
   * 返回结构文件
   * @return 返回结构文件
   */
  const std::string &param_path() const;

  /**
   * 返回权重文件
   * @return 返回权重文件
   */
  const std::string &bin_path() const;

  /**
   * 计算图的初始化
   * @return 是否初始化成功
   */
  bool Init();

  const std::vector<std::shared_ptr<RuntimeOperator>> &operators() const;

private:
  /**
   * 初始化kuiper infer计算图节点中的输入操作数
   * @param inputs pnnx中的输入操作数
   * @param runtime_operator 计算图节点
   */
  static void InitGraphOperatorsInput(
      const std::vector<pnnx::Operand *> &inputs,
      const std::shared_ptr<RuntimeOperator> &runtime_operator);

  /**
   * 初始化kuiper infer计算图节点中的输出操作数
   * @param outputs pnnx中的输出操作数
   * @param runtime_operator 计算图节点
   */
  static void InitGraphOperatorsOutput(
      const std::vector<pnnx::Operand *> &outputs,
      const std::shared_ptr<RuntimeOperator> &runtime_operator);

  /**
   * 初始化kuiper infer计算图中的节点属性
   * @param attrs pnnx中的节点属性
   * @param runtime_operator 计算图节点
   */
  static void
  InitGraphAttrs(const std::map<std::string, pnnx::Attribute> &attrs,
                 const std::shared_ptr<RuntimeOperator> &runtime_operator);

  /**
   * 初始化kuiper infer计算图中的节点参数
   * @param params pnnx中的参数属性
   * @param runtime_operator 计算图节点
   */
  static void
  InitGraphParams(const std::map<std::string, pnnx::Parameter> &params,
                  const std::shared_ptr<RuntimeOperator> &runtime_operator);

public:
private:
  std::string input_name_;  /// 计算图输入节点的名称
  std::string output_name_; /// 计算图输出节点的名称
  std::string param_path_;  /// 计算图的结构文件
  std::string bin_path_;    /// 计算图的权重文件

  std::vector<std::shared_ptr<RuntimeOperator>> operators_;
  std::map<std::string, std::shared_ptr<RuntimeOperator>> operators_maps_;

  std::unique_ptr<pnnx::Graph> graph_; /// pnnx的graph
};

计算图的初始化

bool RuntimeGraph::Init() {
    if (this->bin_path_.empty() || this->param_path_.empty()) {
        LOG(ERROR) << "The bin path or param path is empty";
        return false;
    } // 检查权重文件

    this->graph_ = std::make_unique<pnnx::Graph>(); // make函数,用来将参数转发给动态分配对象,然后返回对象的智能指针
    int load_result = this->graph_->load(param_path_, bin_path_); // 调用PNNX的load读取参数
    if (load_result != 0) {
        LOG(ERROR) << "Can not find the param path or bin path: " << param_path_
                   << " " << bin_path_;
        return false;
    } // 检查读取结果

    std::vector<pnnx::Operator *> operators = this->graph_->ops; // 获取PNNX的算子列表
    if (operators.empty()) {
        LOG(ERROR) << "Can not read the layers' define";
        return false;
    }

    this->operators_.clear(); 
    this->operators_maps_.clear(); // 主动释放vector内存
    for (const pnnx::Operator *op: operators) { // 处理每个算子
        if (!op) {
            LOG(ERROR) << "Meet the empty node";
            continue;
        } else {
            std::shared_ptr<RuntimeOperator> runtime_operator =
                    std::make_shared<RuntimeOperator>();
            // 初始化算子的名称
            runtime_operator->name = op->name;
            runtime_operator->type = op->type;

            // 初始化算子中的input
            const std::vector<pnnx::Operand *> &inputs = op->inputs;
            if (!inputs.empty()) {
                InitGraphOperatorsInput(inputs, runtime_operator);
            }

            // 记录输出operand中的名称
            const std::vector<pnnx::Operand *> &outputs = op->outputs;
            if (!outputs.empty()) {
                InitGraphOperatorsOutput(outputs, runtime_operator);
            }

            // 初始化算子中的attribute(权重)
            const std::map<std::string, pnnx::Attribute> &attrs = op->attrs;
            if (!attrs.empty()) {
                InitGraphAttrs(attrs, runtime_operator);
            }

            // 初始化算子中的parameter
            const std::map<std::string, pnnx::Parameter> &params = op->params;
            if (!params.empty()) {
                InitGraphParams(params, runtime_operator);
            }
            this->operators_.push_back(runtime_operator); // 完成所有初始化后,插入到operators数组中
            this->operators_maps_.insert({runtime_operator->name, runtime_operator}); // 插入到map中
        }
    }

    return true;
}

实现中Operand,Attribute和Param的初始化方法见下文。

初始化RuntimeOperator中的RuntimeOperand

提取PNNX中的操作数到RuntimeOperand中,对应实现为InitGraphOperatorsInputInitGraphOperatorsOutput

for (const pnnx::Operator *op : operators){
    inputs = op->inputs;
    InitGraphOperatorsInput(inputs, runtime_operator);
    ...
}
void RuntimeGraph::InitGraphOperatorsInput(
        const std::vector<pnnx::Operand *> &inputs,
        const std::shared_ptr<RuntimeOperator> &runtime_operator) {
    for (const pnnx::Operand *input: inputs) { // 根据pnnx的每个操作数,初始化RuntimeOperator中的每个操作数
        if (!input) {
            continue;
        }
        const pnnx::Operator *producer = input->producer;
        std::shared_ptr<RuntimeOperand> runtime_operand =
                std::make_shared<RuntimeOperand>();
        runtime_operand->name = producer->name;
        runtime_operand->shapes = input->shape; // 设置name,shape

        switch (input->type) { // 设置type
            case 1: {
                runtime_operand->type = RuntimeDataType::kTypeFloat32;
                break;
            }
            case 0: {
                runtime_operand->type = RuntimeDataType::kTypeUnknown;
                break;
            }
            default: {
                LOG(FATAL) << "Unknown input operand type: " << input->type;
            }
        }
        runtime_operator->input_operands.insert({producer->name, runtime_operand}); // 记录producer,即产生这个操作数的运算符
        runtime_operator->input_operands_seq.push_back(runtime_operand); // 顺序排列
    }
}
const std::vector<pnnx::Operand*>& outputs = op->outputs;
InitGraphOperatorsOutput(outputs, runtime_operator);  
void RuntimeGraph::InitGraphOperatorsOutput(
        const std::vector<pnnx::Operand *> &outputs,
        const std::shared_ptr<RuntimeOperator> &runtime_operator) {
    for (const pnnx::Operand *output: outputs) {
        if (!output) {
            continue;
        }
        const auto &consumers = output->consumers; // 获取cusumer,即输出的操作数
        for (const auto &c: consumers) {
            runtime_operator->output_names.push_back(c->name); // 记录cusumer的名字,输出操作数在这里不进行构建
        }
    }
}

初始化RuntimeAttribute

const std::map<std::string, pnnx::Attribute>& attrs = op->attrs;
InitGraphAttrs(attrs, runtime_operator);
void RuntimeGraph::InitGraphAttrs(
    const std::map<std::string, pnnx::Attribute>& attrs,
    const std::shared_ptr<RuntimeOperator>& runtime_operator) {
  for (const auto& [name, attr] : attrs) {
    switch (attr.type) {
      case 1: {
        std::shared_ptr<RuntimeAttribute> runtime_attribute =
            std::make_shared<RuntimeAttribute>();
        runtime_attribute->type = RuntimeDataType::kTypeFloat32;
        runtime_attribute->weight_data = attr.data;
        runtime_attribute->shape = attr.shape; // 记录信息
        runtime_operator->attribute.insert({name, runtime_attribute}); // 插入名字和信息,在Linear中,name就是weight或bias
        break;
      }
      default: {
        LOG(FATAL) << "Unknown attribute type: " << attr.type;
      }
    }
  }
}

初始化RuntimeParam

void RuntimeGraph::InitGraphParams(
        const std::map<std::string, pnnx::Parameter> &params,
        const std::shared_ptr<RuntimeOperator> &runtime_operator) {
    for (const auto &[name, parameter]: params) { // 分别处理每个param
        const int type = parameter.type;
        switch (type) { // 按不同类别处理
            // 在这里写出实现
            // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others,对应分case即可
            case int(RuntimeParameterType::kParameterUnknown) : {
                RuntimeParameter* runtime_parameter = new RuntimeParameter;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterBool) : { // bool,需要赋值
                RuntimeParameterBool* runtime_parameter = new RuntimeParameterBool;
                runtime_parameter->value = parameter.b;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;

            }
            case int(RuntimeParameterType::kParameterInt) : {
                RuntimeParameterInt* runtime_parameter = new RuntimeParameterInt;
                runtime_parameter->value = parameter.i;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterFloat) : {
                RuntimeParameterFloat* runtime_parameter = new RuntimeParameterFloat;
                runtime_parameter->value = parameter.f;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterString) : {
                RuntimeParameterString* runtime_parameter = new RuntimeParameterString;
                runtime_parameter->value = parameter.s;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterIntArray) : {
                RuntimeParameterIntArray* runtime_parameter = new RuntimeParameterIntArray;
                runtime_parameter->value = parameter.ai;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterFloatArray) : {
                RuntimeParameterFloatArray* runtime_parameter = new RuntimeParameterFloatArray;
                runtime_parameter->value = parameter.af;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            case int(RuntimeParameterType::kParameterStringArray) : {
                RuntimeParameterStringArray* runtime_parameter = new RuntimeParameterStringArray;
                runtime_parameter->value = parameter.as;
                runtime_operator->params.insert({ name, runtime_parameter });
                break;
            }
            default: {
                LOG(FATAL) << "Unknown parameter type: " << type;
            }
        }
    }
}

单元测试

【Kuiperinfer】笔记04 计算图的定义-LMLPHP

TEST(test_ir, pnnx_graph_all_homework) {
  using namespace kuiper_infer;
  /**
   * 如果这里加载失败,请首先考虑相对路径的正确性问题
   */
  std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
  std::string param_path("../../../src/model_file/test_linear.pnnx.param");
  RuntimeGraph graph(param_path, bin_path);
  const bool init_success = graph.Init();
  ASSERT_EQ(init_success, true);
  const auto &operators = graph.operators();
  for (const auto &operator_ : operators) { // 检查operator
    if (operator_->name == "linear") {
      const auto &params = operator_->params;
      ASSERT_EQ(params.size(), 3); // 三个参数,分别为bias,in_features,out_features
        /
      ASSERT_EQ(params.count("bias"), 1);
      RuntimeParameter *parameter_bool = params.at("bias");
      ASSERT_NE(parameter_bool, nullptr);
      ASSERT_EQ((dynamic_cast<RuntimeParameterBool *>(parameter_bool)->value),
                true); 
      /
      ASSERT_EQ(params.count("in_features"), 1);
      RuntimeParameter *parameter_in_features = params.at("in_features");
      ASSERT_NE(parameter_in_features, nullptr);
      ASSERT_EQ(
          (dynamic_cast<RuntimeParameterInt *>(parameter_in_features)->value),
          32);

      /
      ASSERT_EQ(params.count("out_features"), 1);
      RuntimeParameter *parameter_out_features = params.at("out_features");
      ASSERT_NE(parameter_out_features, nullptr);
      ASSERT_EQ(
          (dynamic_cast<RuntimeParameterInt *>(parameter_out_features)->value),
          128);
    }
  }
}

参考

  • 【Kuiperinfer】:https://github.com/zjhellofss/kuiperdatawhale
  • 作者B站主页:https://space.bilibili.com/1822828582?spm_id_from=333.337.search-card.all.click
  • 【Armadillo Docs】:https://arma.sourceforge.net/docs.html
  • 【PNNX】:https://github.com/pnnx/pnnx
03-10 12:15