树状数组(BIT)—— 一篇就够了

前言、内容梗概

本文旨在讲解:

Fenwick便开始思考

树状数组的经典例题及其技巧

模板题:单点修改,区间查询

思路:

非常简单,只需要套模板即可。

代码:

// 上述模板部分省略
using ll = long long;
const int maxn = 1e6+50;
ll f[maxn];
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);

    int n; cin >> n;
    int q; cin >> q;
    for (int i = 1; i <= n; ++ i) cin >> f[i];
    BIT<ll> bit(f, n);

    for (int i = 0; i < q; ++ i){
        int type; cin >> type;
        if (type == 1){
            int i, x;
            cin >> i >> x;
            bit.add(i, (ll) x);
        }else {
            int l, r;
            cin >> l >> r;
            cout << bit.sum(l, r) << '\n';
        }
    }

    return 0;
}

模板题:区间修改,区间查询

思路:

该模板题则难上许多,需要对问题分析建模。

我们需要考虑如何建模表示 \(tree\) 数组。

首先,设更新操作为:在 \([l, r]\) 上增加 \(x\)。我们考虑如何建模维护新的区间前缀和 \(c^{\prime}[i]\)。

下面分情况讨论:

  1. \(i < l\)

这种情况下,不需要任何处理, \(c^{\prime}[i] = c[i]\)

  1. \(l <= i <= r\)

这种情况下,\(c^{\prime}[i] = c[i] + (i - l + 1) \cdot x\)

  1. \(i > r\)

这种情况下,\(c^{\prime}[i]=c[i] + (r-l+1)\cdot x\)

因此如下图所示,我们可以设两个 BIT,那么\(c^{\prime}[i] = \mathrm{sum(bit_1,i)+sum(bit_2,i) \cdot i}\),对于区间修改等价于:

  • 在 \(bit_1\) 的 \(l\) 位置加上 \(-x(l-1)\),在 \(bit_1\) 的 \(r\) 位置加上 \(rx\)。
  • 在 \(bit_2\) 的 \(l\) 位置加上 \(x\) 的 \(r\) 位置加上 \(-x\)。

树状数组(BIT)—— 一篇就够了-LMLPHP

代码

#include <bits/stdc++.h>
using namespace std;
// 模板代码省略
// 这里做的是单点查询,但是实现的为区间查询
using ll = long long;
ll get_sum(BIT<ll> &a, BIT<ll> &b, int l, int r){
    auto sum1 = a.sum(r) * r + b.sum(r);
    auto sum2 = a.sum(l - 1) * (l - 1) + b.sum(l - 1);
    return sum1 - sum2;
}


int n, q;   
const int maxn = 1e6 + 50;
ll f[maxn];
int main(){
    // ios::sync_with_stdio(0);
    // cin.tie(0);
    
    cin >> n >> q;
    BIT<ll> bit1, bit2;
    for (int i = 1; i <= n; ++ i) cin >> f[i];
    bit1.init(n), bit2.init(f, n);

    for (int i = 0; i < q; ++ i){
        int type; cin >> type;
        if (type == 1){
            int l, r, x;
            cin >> l >> r >> x;
            bit2.add(l, (ll) -1 * (l - 1) * x), bit2.add(r + 1, (ll) r * x);
            bit1.add(l, (ll) x), bit1.add(r + 1, (ll) -1 * x);
        }else {
            int i; cin >> i;
            cout << get_sum(bit1, bit2, i, i) << '\n';
        }
    }
    return 0;
}

逆序对 简单版

思路

BIT 求解逆序对是非常方便的,在初学时我没有想到过 BIT 还能用于求解逆序对。在这里我借逆序对来引出一个小技巧:离散化

BIT 求逆序对的方法非常简单,逆序对指:i < j and a[i] > a[j],统计逆序对实际上就是统计在该元素 a[i] 之前有多少元素大于他。

我们可以初始化一个大小为 \(maxn\) 的空 BIT(全为0)。随后:

  1. 我们顺序访问数组中的每个元素 a[i] ,计算区间 [1, a[i]] 的和,更新答案 ans = i - sum([1, a[i]])
  2. 然后,我们更新 BIT 中坐标 a[i] 的值,tree[a[i]] <- tree[a[i]] + 1

举个例子:

eg: [2,1,3,4]
BIT: 0, 0, 0, 0
>2, sum(2) = 0, ans += 0 - sum(2) -> ans = 0
BIT: 0, 1, 0, 0
>1, sum(1) = 0, ans += 1 - sum(1) -> ans = 1
BIT: 1, 1, 0, 0
>3, sum(3) = 2, ans += 2 - sum(3) -> ans = 1
BIT: 1, 1, 1, 0
>4, sum(4) = 3, ans += 3 - sum(4) -> ans = 1

实际上,便是借助 BIT 高效计算前缀和的性质实现了快速打标记,先统计在我之前有多少个标记(这些都是合法对),再将自己所在位置的标记加 \(1\)。

因此,很容易写出这段代码:

代码一

// 仅保留核心代码
int reversePairs(vector<int>& nums) {
    int n = nums.size();
    if (n == 0) return 0;
    int mx = *max_element(nums.begin(), nums.end()); 
    BIT<int> bit(mx); // 因为最大只到最大值的位置
    int ans(0);
    for (int i = 0; i < n; ++ i){
        ans += (i - bit.sum(nums[i]));
        bit.add(nums[i], 1);
    }
    return ans;
}

但是这个代码有非常严重的问题,首先假如 mx = 1e9 就会出现段错误;或者假如 nums[i] < 0 则会出现访问越界的问题,但是实际上题目中说明了:数组最多只有 50000个元素,也就是我们需要想办法将坐标离散化,保留其大小顺序即可。

代码二

#define lb lower_bound
#define all(x) x.begin(), x.end()
const int maxn = 5e4 + 50;
struct node{
    int v, id;
}f[maxn]; // 离散化结构体
int arr[maxn];
bool cmp(const node&a, const node &b){
    return a.v < b.v;
}
class Solution {
public:
    int reversePairs(vector<int>& nums) {
        int n = nums.size();
        if (n == 0) return 0;
        BIT<int> bit(n);

        for (int i = 1; i <= n; ++ i){
            f[i].v = nums[i - 1], f[i].id = i; // 赋值用于排序
        }
        sort(f + 1, f + 1 + n, cmp); 
        int cnt = 1, i = 1;
        while (i <= n){
            /* 用于去重,当有相同元素时其对应的 cnt 应该相同 */
            if (f[i].v == f[i - 1].v || i == 1) arr[f[i].id] = cnt;
            else arr[f[i].id] = ++cnt;
            ++ i;
        }

        int ans = 0;
        for (int i = 0; i < n; ++ i){
            int pos = arr[i + 1];
            ans += i - bit.sum(pos);
            bit.add(pos, 1);
        }
        return ans;
    }
};

上面的方法是离散化操作的一种方式,有一点复杂,需要注意的细节比较多。

实际上,该方法便是通过保留每个元素的所在位置,并将其排序,排序后自己在第 \(i\) 个则将其值 arr[id] = i 离散化为 \(i\) 。这样既可以避免负数,过大的数造成的访问或者内存错误,也充分的保留了各元素之间的大小关系。

离散化的复杂度为 \(\mathcal{O(\log n)}\) ,实际上也就是排序的复杂度。

可以发现,结构体方法对于空间要求较大,且在去重方面需要下功夫,稍后我们会讲解另一种离散化方法,你也可以试试用后文的离散化方法再次解决这题。

逆序对加强版: 翻转对

思路

可以看到这题与逆序对的区别在于,翻转对的定义是:i < ja[i] > 2*a[j] 。其大小关系发生了变化,不再是原来单纯的大小关系,而存在值的变化。

我们可以思考下能否用结构体进行离散化,简单思考后发现:假如第 i 个元素离散化之后的编号为 id1 ,则我们无法确定编号为 2 * id1 所对应元素的 val 值之间的关系。可能出现如下情况:

id1 = 1, val = 2
2 * id1 = 2, val' = 3

所以,我们需要思考一个新的方法来进行离散化。需要注意的是,我们的关键点在于:如何快速的询问一个元素在一个数组中是第几大的元素。比如,在数组中快速询问某个值的两倍是第几大的。

实际上,稍微有基础的话答案便非常清晰:二分查找,我们可以首先将数组进行排序,利用 \(lower_{bound}\) 快速找到第一个大于等于该元素所对应的位置,用代码来说的话:pos = lower_bound(nums.begin(), nums.end(), x) - nums.begin() + 1

eg: nums = [3, 2, 4, 7]
farr = sort(nums) -> farr = [2, 3, 4, 7]
pos(4) = lower_bound(..., 4) - farr.begin() + 1 = 3
便可以快速找到 4 的编号为 3 (1-index)

但是,有一个问题需要注意:

eg: nums = [3, 2, 5, 7]
farr = sort(nums) -> farr = [2, 3, 5, 7]
pos(4) = lower_bound(...,4) - farr.beign() + 1 = 3
但实际上,5 > 4,这次询问错误了!!!

为什么会出现询问错误的情况呢?(因此我们需要找到的是最后一个小于等于元素 x 的对应位置,而二分查找是大于等于 x 的第一个元素,当原数组中不存在 x 时,便会出现询问出错的情况。)

有多种方法可以解决这个问题,但是最为方便的还是直接将需要查询的元素全部加进去,也就是 2 * x 全部添加到数组中,从而保证一定存在该元素,又因为 lower_bound 的性质,我们无需去重。

代码

using vi = vector<int>;
using vl = vector<ll>;
#define complete_unique(x) (x.erase(unique(x.begin(), x.end()), x.end()))
#define lb lower_bound
class Solution {
public:
    int reversePairs(vector<int>& nums) {

        vl tarr;
        for (auto &e: nums){
            tarr.push_back(e);
            tarr.push_back(2ll * e); // 直接把需要离散化的对应元素加入
        }

        sort(tarr.begin(), tarr.end());
        int n = nums.size();
        BIT<int> bit(2 * n); // 注意,因为加入了两倍的元素,所以对应也要开大一点
        int res = 0;

        for (int i = 0; i < n; ++ i){
            res += i - bit.sum(lb(tarr.begin(), tarr.end(), 2ll * nums[i]) - tarr.begin() + 1);
            bit.add(lb(tarr.begin(), tarr.end(), nums[i]) - tarr.begin() + 1, 1);
        }
        return res;
    }
};

二维BIT:区间查询,单点修改

思路

二维 BIT 实际上就是套娃,一层层套即可。

其复杂度为 \(\mathcal{O(\log n \times \log m)}\) ,\(n,m\)分别为每个维度 BIT 的个数,这里不再赘述。

代码

#include <bits/stdc++.h>
using namespace std;
// 模板代码省略
using ll = long long;
int n, m, q;   
const int maxn = 5e3 + 50;
BIT<ll> f[maxn]; // 二维BIT

void add(int i, int j, ll x){
    while (i <= n){
        f[i].add(j, x);
        i += lowbit(i);
    }
}
ll sum(int i, int j){
    ll res(0);
    while (i > 0){
        res += f[i].sum(j);
        i -= lowbit(i);
    }
    return res;
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 1; i <= n; ++ i) f[i] = BIT<ll>(m);

    int type;
    while (cin >> type){
        if (type == 1){
            int x, y, k; cin >> x >> y >> k;
            add(x, y, (ll) k);
        }else {
            int a, b, c, d; cin >> a >> b >> c >> d;
            cout << sum(c, d) - sum(c, b - 1) - sum(a - 1, d) + sum(a - 1, b - 1) << '\n';
        }
    }
    
    return 0;
}

后记

这是我耗时最长的一篇博客,也是我花费心血最多的一次,也希望自己能好好掌握 BIT

附上参考链接:

10-16 07:27