Transform用例总结

  1. 该用例调用JIT的setTransform接口,传入pass对IR代码做了一系列优化。
  2. 优化一:fac函数的调用者能直接拿到返回值,不再需要进入fac函数进行计算。
    • 正常函数调用a = fac(5)需要进入fac函数后才能拿到结果120。
    • transform后,a = fac(5)被替换为a = 120,即把计算提前到编译期完成,缩短了运行时间。
  3. 优化二:fac函数内的递归调用被拉平了,改为在函数内部用跳转(goto)完成,避免了递归函数调用的压栈开销,缩短了运行时间。

总结:
LLVM(6)ORC实例分析:Transform in cpp-LMLPHP

完整用例

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"

#include "ExampleModules.h"

using namespace llvm;
using namespace llvm::orc;

// Shared LLVM ExitOnError helper: main() sets its banner and wraps every
// fallible call (JIT creation, module parsing/adding, symbol lookup) in it.
ExitOnError ExitOnErr;

// Textual LLVM IR for the example module that gets JIT-compiled.
// It defines two functions:
//   @fac(i32 %n) - recursive factorial: returns 1 when %n == 0, otherwise
//                  %n * fac(%n - 1).
//   @entry()     - calls fac(5) and returns the result.
// The text is parsed at runtime, so it must remain valid LLVM IR.
const llvm::StringRef MainMod =
    R"(

  define i32 @fac(i32 %n) {
  entry:
    %tobool = icmp eq i32 %n, 0
    br i1 %tobool, label %return, label %if.then

  if.then:                                          ; preds = %entry
    %arg = add nsw i32 %n, -1
    %call_result = call i32 @fac(i32 %arg)
    %result = mul nsw i32 %n, %call_result
    br label %return

  return:                                           ; preds = %entry, %if.then
    %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
    ret i32 %final_result
  }

  define i32 @entry() {
  entry:
    %result = call i32 @fac(i32 5)
    ret i32 %result
  }

)";

class MyOptimizationTransform {
public:
  MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
    PM->add(createTailCallEliminationPass());
    PM->add(createFunctionInliningPass());
    PM->add(createIndVarSimplifyPass());
    PM->add(createCFGSimplificationPass());
  }

  Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
                                        MaterializationResponsibility &R) {
    TSM.withModuleDo([this](Module &M) {
      dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
      PM->run(M);
      dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
    });
    return std::move(TSM);
  }

private:
  std::unique_ptr<legacy::PassManager> PM;
};

/// Program entry point: adds MainMod to an LLJIT instance that has
/// MyOptimizationTransform installed, looks up and calls the JIT'd `entry`
/// function, and prints its return value to outs().
int main(int argc, char *argv[]) {
  // LLVM process-level initialization, followed by registration of the
  // native target and its assembly printer.
  InitLLVM Init(argc, argv);
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();

  // Prefix ExitOnErr's messages with the program name.
  ExitOnErr.setBanner(std::string(argv[0]) + ": ");

  // Step 1: create the JIT instance.
  // (Lazy alternative: ExitOnErr(LLLazyJITBuilder().create()).)
  auto Jit = ExitOnErr(LLJITBuilder().create());

  // Step 2: install the transform so modules are optimized when they are
  // materialized.
  Jit->getIRTransformLayer().setTransform(MyOptimizationTransform());

  // Step 3: parse the example IR and add the resulting module to the JIT.
  // (Lazy alternative: Jit->addLazyIRModule(...) with the same module.)
  auto ParsedMod = ExitOnErr(parseExampleModule(MainMod, "MainMod"));
  ExitOnErr(Jit->addIRModule(std::move(ParsedMod)));

  // Step 4: look up the JIT'd `entry` symbol and call it through a plain
  // function pointer.
  const auto EntryAddr = ExitOnErr(Jit->lookup("entry"));
  int (*EntryFn)() = EntryAddr.toPtr<int()>();

  const int Result = EntryFn();
  outs() << "--- Result ---\n"
         << "entry() = " << Result << "\n";

  return 0;
}
10-17 13:30