这是一般的线段树代码(注意sum函数):
#include <iostream>
#include <cstdio>
using namespace std;
#define ll long long
struct Tree {
ll l , r , lazy , k;
};
ll n , m , a[100050] , cnt;
Tree tr[400050];
void unionn (ll num)
{
tr[num].k = tr[ tr[num].l ].k + tr[ tr[num].r ].k;
return;
}
inline void f (ll num , ll l , ll r , ll k)
{
tr[num].lazy += k;
tr[num].k += (r - l + 1) * k;
}
inline void down_out (ll num , ll l , ll r)
{
const ll mid = (l+r) >> 1;
f(tr[num].l , l , mid , tr[num].lazy);
f(tr[num].r , mid+1 , r , tr[num].lazy);
tr[num].lazy = 0;
}
void update (ll num , ll arr_l , ll arr_r , ll l , ll r , ll k)
{
if(arr_l <= l && arr_r >= r)
{
tr[num].k += (r - l + 1) * k;
tr[num].lazy += k;
return;
}
down_out (num , l , r);
const ll mid = (l+r) >> 1;
if(arr_l <= mid)
update (tr[num].l , arr_l , arr_r , l , mid , k);
if(arr_r >= mid+1)
update (tr[num].r , arr_l , arr_r , mid+1 , r , k);
unionn (num);
return;
}
ll sum (ll num , ll arr_l , ll arr_r , ll l , ll r)
{
down_out (num , l , r);
//unionn (num);
if(arr_l <= l && arr_r >= r)
return tr[num].k;
ll ans = 0 , mid = (l+r) >> 1;
if(arr_l <= mid)
ans += sum (tr[num].l , arr_l , arr_r , l , mid);
if(arr_r >= mid+1)
ans += sum (tr[num].r , arr_l , arr_r , mid+1 , r);
return ans;
}
void build (ll num , ll l , ll r)
{
if(l == r)
{
tr[num].k = a[l];
return;
}
tr[num].l = ++ cnt;
tr[num].r = ++ cnt;
ll mid = (l+r) >> 1;
build (tr[num].l , l , mid);
build (tr[num].r , mid+1 , r);
unionn (num);
}
int main()
{
cnt = 1;
scanf ("%lld%lld",&n,&m);
for(int i = 1 ; i <= n ; i ++)
scanf ("%lld",&a[i]);
build (1 , 1 , n);
while (m --)
{
ll x , y , k , op;
scanf("%lld%lld%lld",&op,&x,&y);
if(op == 1)
{
scanf("%lld",&k);
update(1 , x , y , 1 , n , k);
}
else {
printf("%lld\n",sum(1 , x , y , 1 , n));
}
}
return 0;
}
而这是经过略微修改的线段树代码(无法通过此题,去掉了sum函数中注释的unionn函数):
#include <iostream>
#include <cstdio>
using namespace std;
#define ll long long
struct Tree {
ll l , r , lazy , k;
};
ll n , m , a[100050] , cnt;
Tree tr[400050];
void unionn (ll num)
{
tr[num].k = tr[ tr[num].l ].k + tr[ tr[num].r ].k;
return;
}
inline void f (ll num , ll l , ll r , ll k)
{
tr[num].lazy += k;
tr[num].k += (r - l + 1) * k;
}
inline void down_out (ll num , ll l , ll r)
{
const ll mid = (l+r) >> 1;
f(tr[num].l , l , mid , tr[num].lazy);
f(tr[num].r , mid+1 , r , tr[num].lazy);
tr[num].lazy = 0;
}
void update (ll num , ll arr_l , ll arr_r , ll l , ll r , ll k)
{
if(arr_l <= l && arr_r >= r)
{
tr[num].k += (r - l + 1) * k;
tr[num].lazy += k;
return;
}
down_out (num , l , r);
const ll mid = (l+r) >> 1;
if(arr_l <= mid)
update (tr[num].l , arr_l , arr_r , l , mid , k);
if(arr_r >= mid+1)
update (tr[num].r , arr_l , arr_r , mid+1 , r , k);
unionn (num);
return;
}
ll sum (ll num , ll arr_l , ll arr_r , ll l , ll r)
{
down_out (num , l , r);
unionn (num);
if(arr_l <= l && arr_r >= r)
return tr[num].k;
ll ans = 0 , mid = (l+r) >> 1;
if(arr_l <= mid)
ans += sum (tr[num].l , arr_l , arr_r , l , mid);
if(arr_r >= mid+1)
ans += sum (tr[num].r , arr_l , arr_r , mid+1 , r);
return ans;
}
void build (ll num , ll l , ll r)
{
if(l == r)
{
tr[num].k = a[l];
return;
}
tr[num].l = ++ cnt;
tr[num].r = ++ cnt;
ll mid = (l+r) >> 1;
build (tr[num].l , l , mid);
build (tr[num].r , mid+1 , r);
unionn (num);
}
int main()
{
cnt = 1;
scanf ("%lld%lld",&n,&m);
for(int i = 1 ; i <= n ; i ++)
scanf ("%lld",&a[i]);
build (1 , 1 , n);
while (m --)
{
ll x , y , k , op;
scanf("%lld%lld%lld",&op,&x,&y);
if(op == 1)
{
scanf("%lld",&k);
update(1 , x , y , 1 , n , k);
}
else {
printf("%lld\n",sum(1 , x , y , 1 , n));
}
}
return 0;
}
按照我的思路来讲,如果这个区间的 lazytag 下放后再合并应该和原值相同,而不会出现答案错误的情况,希望大佬求解。
样例输入:
5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
样例输出:
11
8
20