樹狀數組-從入門到拓展詳解

_Aking 2021-09-19 14:13:44 阅读数:601

拓展

樹狀數組-從入門到拓展

樹狀數組入門

期間如有問題,歡迎評論區討論

樹狀數組是一個可以在O(log2n)的時間複雜度下實現修改和查詢的數據結構,因此對於我們在競賽中起著重要作用

為了能够直觀的認識這個時間複雜的意義,我們看下面這個問題

給定長度為n的序列

如果要求我們求出下標區間l-r內數的總和,我們可能會想到直接兩個前綴和想减即可

如果我要求把第k個數修改一下,那我們直接修改即可

但是,重點來了,如果我給出m個詢問,一共分為兩種詢問

第一種是需要你對第k個數進行修改

第二種是需要你對當前區間l-r求和,那麼,還可以直接算嗎?

很顯然是不行的,例如我有八個數,我要求2-7之間的區間和,我可以sum[7] - sum[1],如果我下一步是讓你第二個數的值加上x,那麼後面的這些前綴和就都需要重新算一邊,當我們詢問次數高達十的五次方次的時候,顯示這種暴力的方法是行不通的

好,帶著問題我們開始接觸樹狀數組,首先看一下樹狀數組

假設我們給出八個數(a[1]、a[2]、....a[8])、那麼我們定義樹狀數組tr

tr[1] = a[1];

tr[2] = a[1] + a[2];

tr[3] = a[3];

tr[4] = a[1] + a[2] + a[3] + a[4];

tr[5] = a[5];

tr[6] = a[5] + a[6];

tr[7] = a[7];

tr[8] = a[1] + a[2] + a[3] + a[4] + a[5] + a[6] + a[7] + a[8];

我們使奇數項直接等於我們的原數組,偶數項則是一種和的形式,為了方便理解,下面放上樹狀數組的一個圖形

空白方格不需要管,我們只需要知道,箭頭代錶我當前方格管理的方格有哪些,例如tr[6]就可以管理a[5] + a[6]

為什麼像這樣定義呢

當我們需要求出前六項的前綴和的時候,我們想一下,我們只需要tr[6] + tr[4]就可以得到了

當我們需要求出前七項的前綴和的時候,我們只需要tr[7] + tr[6] + tr[4]即可

再者我們考慮修改,假設我需要修改a[2],其實我們發現,包含a[2]的只有tr[2]、tr[4]、tr[8]這三項

這樣我們在邊修改邊查詢的時候,時間複雜度就會降低很多

那麼現在我們考慮,這6應該跟4、6關聯起來了呢,我們看一下4、6的二進制

6 :110

4: 100

再看一下前七項前綴和需要用到的7、6、4的

7 : 111

6 : 110

4 : 100

規律就是,每次將二進制中最低比特的1减去,直到减完即可

對於修改,我們考慮2和2、4、8的關系

2 : 10

4 : 100

8 : 1000

再考慮包括3的有哪些,3、4、8

3 : 11

4 : 100

8 : 1000

規律就是每次加上二進制中最低比特的1,直到超過n

二進制最低比特的1也有響應的算法,我們稱之為lowbit函數

int lowbit(int x)// 返回x的最低比特1
{
return x & -x;
}

這裏是利用了負數在計算機內存儲形式為補碼的特點,感興趣的可以自己計算一下

單點修改、區間查詢

了解了樹狀數組的內容,和lowbit函數,接下來就是如何實現單點修改和區間查詢了

對於單點修改,我們上面提到過,從該點開始,每次加上lowbit,直到最大

這樣我們就把可以管理到我們當前數的tr數組給初始化完成了

例如a[2] = 2;那麼我就需要把tr[2]、tr[4]、tr[8]都加上這個a[2],因為他們都可以管理到a[2]

// 單點修改
void update(int x, int c) // a[x] = c;
{
for (int i = x;i <= n; i += lowbit(i)) tr[i] += c;// 如果你可以管理到我,那麼你就加上我的值
}

代碼很短,多琢磨琢磨就清楚了

考慮前x個數的和,根據我們上面分析的,每次减去最低比特的1即可

例如我要求前七個數的和,就是res += tr[7] + tr[6] + tr[4];

// 區間查詢
int getsum(int x) // 返回前x個數的和
{
int res = 0;// 前綴和計算結果,用於返回
for (int i = x;i > 0; i -= lowbit(i)) res += tr[i];
return res;
}

下面推薦一道LibreOJ上的例題,推薦這個OJ是因為在這上面提交後是可以看到任何一個測試點的,方便調試

#130. 樹狀數組 1 :單點修改,區間查詢

AC代碼如下

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;
int a[N];// 原數組
ll tr[N];// 樹狀數組
int n, m;
int lowbit(int x)// lowbit函數
{
return x & -x;
}
void update(int x, int c) // a[x] = c
{
for (int i = x;i <= n; i += lowbit(i)) tr[i] += c;
}
ll getsum(int x) // 返回前x個數的和
{
ll res = 0;
for (int i = x;i > 0; i -= lowbit(i)) res += tr[i];
return res;
}
int main() {
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// freopen("in.in", "r", stdin);freopen("out.out", "w", stdout);
cin >> n >> m;
for (int i = 1;i <= n;i ++) {
cin >> a[i];
update(i, a[i]);// 初始化tr數組
}
for (int i = 1;i <= m;i ++) {
int id, l, r;
cin >> id >> l >> r;
if (id == 1) {// 修改
update(l, r);// 在l的比特置上加r,並更新後面可以管理到他的
}else {// 查詢
ll ans = getsum(r) - getsum(l - 1);// 區間和
cout << ans << endl;
}
}
return 0;
}

區間修改、區間查詢

說明:線段樹掛懶標記可更簡單的解决,如果實在不想看樹狀數組的,可以跳過這裏

既然單獨提出來了,那麼一定是有特點的,不能簡簡單單考慮,我存的時候存差分不久可以了嗎

仔細想想,如果存的時候存的是差分數組,那麼你只能進行簡單的單點查詢,為了能够實現區間查詢,我們就要好好考慮一下了

首先對於一個區間和a[1]+a[2]+...+a[n]

定義c為差分數組,那麼我們可以得到a[1]+a[2]+...+a[n] = (c[1]) + (c[1]+c[2]) + ... + (c[1]+c[2]+...+c[n])

我們進一步將公式轉換= n*c[1] + (n-1)*c[2] +... +c[n]

最終我們得到求和公式= n * (c[1]+c[2]+...+c[n]) - (0*c[1]+1*c[2]+...+(n-1)*c[n])

接下來就可以開始愉快的敲代碼了

我們只需要維護兩個樹狀數組c1、c2,其中c1存我們的差分數組,c2存我們的差分數組*系數

推薦題目依舊是LibreOJ上的模板題

#132. 樹狀數組 3 :區間修改,區間查詢

AC代碼如下

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;
int n, q;
ll a[N], c1[N], c2[N];
void update(ll *h, int x, ll y){
while(x <= n){
h[x] += y;
x += x & -x;
}
}
ll getsum(ll *h, int x){
ll res = 0;
while(x){
res += h[x];
x -= x & -x;
}
return res;
}
int getsum(int x)//求差分數組c1的前綴和->修改後的原數組
{
int ans = 0;
while (x)
{
ans += c1[x];
x -= lowbit(x);
}
return ans;
}
int main(int argc, char const *argv[])
{
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
cin >> n >> q;
for (int i = 1; i <= n; ++i)
{
cin >> a[i];
update(c1, i, a[i] - a[i - 1]);// c1存 差分
update(c2, i, (i - 1) * (a[i] - a[i - 1]));// c2存 差分*系數
}
for(int i = 1;i <= q;i++){
ll id, l, r, x;
cin >> id >> l >> r;
if(id == 1){
cin >> x;
update(c1, l, x);
update(c1, r + 1, -x);
update(c2, l, x * (l - 1));
update(c2, r + 1, -x * r);
}else{
ll sum1 = (l - 1) * getsum(c1, l - 1) - getsum(c2, l - 1);// 求和公式
ll sum2 = r * getsum(c1, r) - getsum(c2, r);
cout << sum2 - sum1 << endl;
}
}
// for (int i = 1; i <= n; i++)//修改後的原數列、俗稱單點查詢qwq
// cout << getsum(i) << ' ';
return 0;
}

區間最值說明

不再對其進行講解,自行理解即可,很簡潔,最大值最小值思路一樣

void update(int x, ll y)
{
while (x <= n)
{
c[x] = max(c[x], y);// 誰可以管理到我,誰就對我取max,看我是否可以作為最大值
x += x & -x;
}
}
ll getsum(int x)
{
ll ans = 0;
while (x)
{
ans = max(ans, c[x]);// 對於每一個我可以管理的到,取max
x -= x & -x;
}
return ans;
}

拓展一:逆序數問題

逆序數問題只做為興趣,實際情况不一定有遞歸求的快,下面我也會給出遞歸版本的求逆序數

逆序數就是當前數前面有幾個比他大的數

那麼我們用樹狀數組就可以完美的解决這個問題,因為我們樹狀數組就可以求出了一個數前面的前綴和,如果我們每個數都貢獻1,那麼求的就是有幾個數比我小,然後剩下的,就是比我大的了,就可以做為貢獻加入結果中

當然大部分情况下,數會非常大,所以需要離散化一下

#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e6 + 10;
struct node
{
int x, y;
} que[N];
int tree[N], dispers[N];
int n, cnt = 0, sum;
bool cmp(node a, node b) { return a.x < b.x; }
void update(int x)
{
while (x <= n)
{
tree[x]++;// 默認貢獻為1代錶個數
x += x & -x;
}
}
int getsum(int x)
{
int ans = 0;
while (x)
{
ans += tree[x];
x -= x & -x;
}
return ans;
}
int main()
{
// freopen("in.txt", "r", stdin);
while (cin >> n)
{
for (register int i = 1; i <= n; i++)
{
cin >> que[i].x;
que[i].y = i;// 用於離散化
}
sort(que + 1, que + n + 1, cmp);
for (register int i = 1; i <= n; i++)// 離散化
{
cnt++;
// cout << cnt << endl;
dispers[que[i].y] = cnt;
}
for (register int i = 1; i <= cnt; i++)
{
update(dispers[i]);// 默認貢獻為1,代錶個數
sum += (i - getsum(dispers[i]));// 那麼剩餘的就是比當前數大的了
}
cout << sum << endl;
}
return 0;
}

遞歸版本代碼

原理也比較簡單,在遞歸排序中,我們知道,是通過一個數一個數往前挪的

那麼,對於一個數,你在遞歸排序中,挪了幾次,都加起來,就是這個序列的逆序數了

#include <iostream>
#define INF 0xFFFFF
using namespace std;
long long A[1000000];
long long number;
typedef long long ll;
void Merge(int left, int mid, int right)
{
int len1 = mid - left + 1;
int len2 = right - mid;
int L[len1 + 2], R[len2 + 2];
for (int i = 1; i <= len1; i++)
L[i] = A[left + i - 1];
for (int i = 1; i <= len2; i++)
R[i] = A[mid + i];
L[len1 + 1] = R[len2 + 1] = INF;
int l = 1, r = 1;
for (int i = left; i <= right; i++)
{
if (L[l] <= R[r])
A[i] = L[l++];
else
{
A[i] = R[r++];
number += len1 - l + 1;// 如果需要往前挪,就讓逆序數加上你挪了幾比特
}
}
}
void mergeSort(int left, int right)
{
if (left < right)
{
int mid = (left + right) / 2;
mergeSort(left, mid);
mergeSort(mid + 1, right);
Merge(left, mid, right);
}
}
int main()
{
// freopen("in.txt", "r", stdin);
std::ios::sync_with_stdio(false);
long long n;
while (cin >> n)
{
number = 0;
for (register int i = 1; i <= n; i++)
cin >> A[i];
mergeSort(1, n);
cout << number << endl;
// for(register int i=0;i<n;i++) cout << A[i];
// cout << endl;
}
}

拓展二:上昇子序列問題

推薦例題AcWing上的3662. 最大上昇子序列和

子序列問題大部分是需要dp來求解的

不過用樹狀數組也有奇效

通過樹狀數組的性質,我們知道,對於每個樹狀數組的含義是管理他前面是數,那麼我們就可以不只用來求和,用來求最大值也是可以的

對於本題,首先我們修改一下update和getsum函數

void update(int x, ll y)
{
while (x <= n)
{
c[x] = max(c[x], y);// 誰可以管理到我,誰就對我取max,看我是否可以作為最大值
x += x & -x;
}
}
ll getsum(int x)
{
ll ans = 0;
while (x)
{
ans = max(ans, c[x]);// 對於每一個我可以管理的到,取max
x -= x & -x;
}
return ans;
}

我們定義sum數組錶示以第i個數結尾的最大上昇子序列和,對於我們的序列,我們一個一個判斷

對於當前這個數我們考慮比他小的所有數裏面,最大的子序列和,並實時更新樹狀數組

由於我們n只有十的五次方,但是每個數卻可能非常大,所以需要離散化一下

AC代碼如下

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <map>
#include <unordered_map>
#define INF 0xFFFFFF
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
ll n, tot = 0;
ll que[N], disperse[N], sum[N], amxsum[N], c[N];
unordered_map<int, ll> mp;
void update(int x, ll y)
{
while (x <= n)
{
c[x] = max(c[x], y);
x += x & -x;
}
}
ll getsum(int x)
{
ll ans = 0;
while (x)
{
ans = max(ans, c[x]);
x -= x & -x;
}
return ans;
}
int main()
{
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> que[i];
memcpy(disperse, que, sizeof que);// disperse數組用於存放原數組來離散化
sort(disperse + 1, disperse + n + 1);
for (int i = 1; i <= n; i++)
if (!mp.count(disperse[i]))
mp[disperse[i]] = ++tot;// 每個數是第幾大的數
for (int i = 1; i <= n; i++)// 從頭往後遍曆
{
// 對於在我前面出現並且比我小的數裏面,找子序列和最大的,然後加上我的值
// 由於我是第mp[que[i]]大的數,那麼我的子序列和最小也得是mp[que[i]]
sum[i] = max(mp[que[i]], getsum(mp[que[i]] - 1) + que[i]);
update(mp[que[i]], sum[i]);//邊記錄還要變更新以我為結尾最大的子序列和,方便後面的數進行判斷
}
ll ans = 0;
for (int i = 1; i <= n; i++)
ans = max(ans, sum[i]);
cout << ans << endl;
return 0;
}

這裏是求的最大上昇子序列和問題,相應的還有最長上昇子序列問題,只需要將這裏代碼改一下就行了

for (int i = 1; i <= n; i++)// 從頭往後遍曆
{
sum[i] = max(1ll, getsum(mp[que[i]] - 1) + 1);
update(mp[que[i]], sum[i]);
// sum[i] = max(mp[que[i]], getsum(mp[que[i]] - 1) + que[i]);
// update(mp[que[i]], sum[i]);
}

如果要求非嚴格上昇的話,也只需要把getsum(mp[que[i]] - 1)修改為getsum(mp[que[i]])即可

拓展三:第k小數

推薦題目AcWing244. 謎一樣的牛

思路:二分+樹狀數組

對於中間某一頭牛,我們只知道他前面比他高的,並不清楚他後面有幾個

但是,對於最後一頭牛,如果他前面有k個比他高的,那麼他一定是第k + 1個高的牛

對於倒數第二頭牛,如果他前面有k個比他高的,那麼他一定是除了最後一頭牛以外的,第k + 1個高的牛

圖示

對於第5頭牛,我已經可以確定,他是第1高的,說明他已經占據了第一個比特置,那麼看第4頭牛

因為他前面有一個比它高的,所以我們從1-n進行二分,看那個數前面有1個還存在的高度,然後我們定比特到第4頭牛的高度為3

看第3頭牛,他前面有兩個比它高的,從1-n進行二分,我們定比特到5這個高度的前面還有兩個存在的高度,所以我們定比特到第三頭牛高度為5

以此類推

所以我們就可以從後往前遍曆,每求出一頭牛是第幾高,我們就將這個高度删去,然後去判斷下一頭牛

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 1e5 + 10;
int a[N], c[N];
int n;
vector<int> res;
inline int lowbit(int x) {
return x & -x;
}
inline void update(int x, int cc) {
for(int i = x;i <= n;i += lowbit(i)) {
c[i] += cc;
}
}
inline int get(int x) {
int res = 0;
for(int i = x;i;i -= lowbit(i)) {
res += c[i];
}
return res;
}
int main(){
scanf("%d", &n);
for(int i = 2;i <= n;i++) {
scanf("%d", &a[i]);
}
for (int i = 1;i <= n;i ++) {
update(i, 1);// 最開始,每一個高度都存在
}
for(int i = n;i;i--){
int l = 1, r = n;
while(l < r){
int mid = l + r >> 1;
if(get(mid) >= a[i] + 1) r = mid;
else l = mid + 1;
}
update(r, -1);// 這個高度已經有牛了,將其删去
res.push_back(r);
}
for(auto i = res.rbegin();i != res.rend();i++){// 由於是倒著求答案的,所以要倒著輸出
printf("%d\n", *i);
}
return 0;
}

拓展四:離散化

推薦題目

HDUTuring Tree

T組數據,給定長度為n的序列,m次詢問,每次詢問區間內不同數相加的和

為了能從左往右進行枚舉詢問,以每個區間右端點進行排序

然後遍曆序列裏的n個數

如果這個數在之前出現過,那麼我將之前出現的删去,然後在這個比特置加上該數,因為每個數只能貢獻一次

然後用while循環詢問的右端點,是否有右端點與我們遍曆到的帶你重合了,如果有,那麼這個區間裏的數我一定已經初始化好了,然後去求這個區間

能看到這裏我就不囉嗦這個右端點合理性了,這是完全正確的

#include <iostream>
#include <cstring>
#include <algorithm>
// #include <bits/stdc++.h>
using namespace std;
#define inf 0x7f7f7f7f
typedef long long ll;
const int N = 3e4 + 10, M = 1e5 + 10, mod = 1331;
struct node {
int l, r, k;
bool operator < (const node &t) const {// 自定義以區間右端點排序
return r < t.r;
}
}q[M];
int a[N], b[N];
ll tr[N], res[M];//res存儲第i個區間的答案
int vis[N];// 第i個數最後一次出現的比特置
int n, m;
int lowbit(int x)
{
return x & -x;
}
void update(int x, int c) // 比特置x加c
{
for (int i = x;i <= n; i += lowbit(i)) tr[i] += c;
}
ll getsum(int x) // 返回前x個數的和
{
ll res = 0;
for (int i = x;i > 0; i -= lowbit(i)) res += tr[i];
return res;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// freopen("in.in", "r", stdin);freopen("out.out", "w", stdout);
int T;
cin >> T;
while (T --) {
memset(tr, 0, sizeof tr);
memset(vis, 0, sizeof vis);
cin >> n;
for (int i = 1;i <= n;i ++) {
cin >> a[i];
b[i] = a[i];
}
sort(b + 1, b + n + 1);// 排序用於離散化
for (int i = 1;i <= n;i ++) {// 二分函數離散化
a[i] = lower_bound(b + 1, b + n + 1, a[i]) - b;
}
cin >> m;
for (int i = 1;i <= m;i ++) {
cin >> q[i].l >> q[i].r;
q[i].k = i;
}
sort(q + 1, q + m + 1);
int cnt = 1;
for (int i = 1;i <= n;i ++) {
if (vis[a[i]]) update(vis[a[i]], -b[a[i]]);// 如果你出現過,我先將你前面的貢獻删去
update(i, b[a[i]]);// 在當前比特置上加上貢獻
vis[a[i]] = i;// 記錄當前點,用於後面重複了以後進行删除
// 如果右端點與我枚舉的比特置重合了,那麼這個區間裏的數我一定已經初始化好了,然後去求這個區間
while (cnt <= m && q[cnt].r == i) {
res[q[cnt].k] = getsum(q[cnt].r) - getsum(q[cnt].l - 1);
cnt ++;
}
}
// cout << m << "---" << endl;
for (int i = 1;i <= m;i ++) {
cout << res[i] << endl;
}
}
return 0;
}
版权声明:本文为[_Aking]所创,转载请带上原文链接,感谢。 https://gsmany.com/2021/09/20210919141343909w.html