玄关求调
查看原帖
玄关求调
268984
KW_KL楼主2024/11/24 21:33

有大佬能帮忙看一下我写的这份很乱的代码吗？两个样例都能过，但不知道为什么，凡是包含修改 a 操作的测试点都会 WA。

已经改得和题解几乎一模一样了，为什么还是过不了？

原题qwq

#include<bits/stdc++.h>
// Child indices in the 1-based heap layout of the segment tree (root = 1).
// Fully parenthesised so the macros behave like function calls inside any
// surrounding expression (e.g. l_id(x) + 1 must not become x << 2).
#define l_id(id) ((id) << 1)
#define r_id(id) ((id) << 1 | 1)
#define inf 1e16
using namespace std;
typedef long long ll;
const int N = 1e5 + 5,M = 1e5 + 5,mod = 1e9 + 7;
// n: sequence length, m: number of operations (both read in main).
// p[i]: operator between a[i] and a[i+1]; non-zero means '*', zero means '+'.
// cnt and ans are not referenced in this post's code; x, y and c may be used
// as scratch/cache state by push_down_a.
int n,m,cnt,ans,c,p[N];
ll x = -1,y = -1,a[N];
// Segment-tree node over a[l..r] together with the operators p[l..r].
// A "run" is a maximal stretch of elements joined by '*' (p != 0); the
// expression value is the sum of the products of all runs.
struct Tree{
	int p;                          // operator after a[r], i.e. p[r]
	int l;                          // first index covered
	int r;                          // last index covered
	int len;                        // r - l + 1
	int len_l;                      // length of the leftmost run
	int len_r;                      // length of the rightmost run
	int tag_p;                      // pending operator assignment, -1 = none
	int len_short_interval;         // floor(sqrt(len)): largest "short" run length
	int *short_interval;            // [k] = number of runs of length k, 1 <= k <= len_short_interval
	ll sum;                         // value of a[l] p[l] ... a[r] (mod 1e9+7)
	ll sum_add;                     // a[l] + ... + a[r] (mod 1e9+7)
	ll sum_mul;                     // a[l] * ... * a[r] (mod 1e9+7)
	ll sum_l;                       // product of the leftmost run
	ll sum_r;                       // product of the rightmost run
	ll val_l;                       // a[l]
	ll val_r;                       // a[r]
	ll tag_a;                       // pending value assignment, -1 = none
	std::vector<int> long_interval; // sorted lengths of runs longer than len_short_interval
}tre[N << 2];
// Query result: the subset of Tree fields needed to concatenate two adjacent
// ranges. Member order matters because query() aggregate-initialises it.
struct Tre{
	int p;      // operator after the last element of the range
	int len;    // number of elements in the range
	int len_l;  // length of the leftmost '*' run
	int len_r;  // length of the rightmost '*' run
	ll sum;     // value of the expression over the range (mod 1e9+7)
	ll sum_l;   // product of the leftmost run
	ll sum_r;   // product of the rightmost run
};
// Reads one decimal int from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if input ends before any digit appears.
inline int read()
{
	int x = 0,f = 1;
	// int (not char): keeps EOF distinguishable and is a valid isdigit() argument
	// even for bytes >= 0x80, which a signed char would turn negative (UB).
	int c = getchar();
	while (c != EOF && !isdigit(c))
	{
		if (c == '-')
			f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}
// Reads one decimal long long from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if input ends before any digit appears.
inline ll readd()
{
	ll x = 0,f = 1;
	// int (not char): keeps EOF distinguishable and is a valid isdigit() argument.
	int c = getchar();
	while (c != EOF && !isdigit(c))
	{
		if (c == '-')
			f = -1;
		c = getchar();
	}
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * f;
}
// Writes x in decimal to stdout with no trailing newline.
// Digits are produced from the unsigned magnitude, so LLONG_MIN is printed
// correctly instead of overflowing on -x.
inline void print(ll x)
{
	unsigned long long mag = static_cast<unsigned long long>(x);
	if (x < 0)
	{
		putchar('-');
		mag = 0ULL - mag;   // well-defined modular negation
	}
	char buf[20];           // 2^64 - 1 has 20 digits
	int top = 0;
	do
		buf[top++] = static_cast<char>('0' + mag % 10);
	while (mag /= 10);
	while (top)
		putchar(buf[--top]);
}
// Absolute value of x (overflows for LLONG_MIN, like -x).
inline ll abss(ll x)
{
	return x < 0 ? -x : x;
}
// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
inline ll gcd(ll a,ll b)
{
	while (b != 0)
	{
		const ll rem = a % b;
		a = b;
		b = rem;
	}
	return a;
}
// Larger of x and y.
// Uses a direct comparison: deriving a mask from the sign of (x - y) overflows
// (undefined behaviour, wrong result) when x and y have large opposite-sign values.
inline ll max(ll x,ll y)
{
	return x < y ? y : x;
}
// Smaller of x and y.
// Uses a direct comparison: deriving a mask from the sign of (x - y) overflows
// (undefined behaviour, wrong result) when x and y have large opposite-sign values.
inline ll min(ll x,ll y)
{
	return y < x ? y : x;
}
// Binary exponentiation: returns a^b modulo `mod` (1 when b == 0).
// The base is reduced first so base * base cannot overflow even if the caller
// passes an unreduced value; b <= 0 returns 1 instead of looping forever.
inline ll ksm(ll a,int b)
{
	ll ans = 1,base = a % mod;
	for (;b > 0;b >>= 1)
	{
		if (b & 1)
			ans = ans * base % mod;
		base = base * base % mod;
	}
	return ans;
}
inline void push_up_a(int id,Tree &l,Tree &r)
{
	tre[id].val_l = l.val_l;
	tre[id].val_r = r.val_r;
	tre[id].sum_add = (l.sum_add + r.sum_add) % mod;
	tre[id].sum_mul = l.sum_mul * r.sum_mul % mod;
	tre[id].p = r.p;
	if (l.p)	
		tre[id].sum = ((((l.sum + r.sum) % mod - l.sum_r) % mod - r.sum_l + mod) % mod + (l.sum_r * r.sum_l) % mod + mod) % mod;
	else
		tre[id].sum = (l.sum + r.sum) % mod;
	if (l.p && l.len_l == l.len)
	{
		tre[id].sum_l = l.sum * r.sum_l % mod;
		tre[id].len_l = l.len + r.len_l;
	}
	else
	{
		tre[id].sum_l = l.sum_l;
		tre[id].len_l = l.len_l;
	}
	if (l.p && r.len_r == r.len)
	{
		tre[id].sum_r = l.sum_r * r.sum % mod;
		tre[id].len_r = l.len_r + r.len;
	}
	else
	{
		tre[id].sum_r = r.sum_r;
		tre[id].len_r = r.len_r;
	}
}
inline void push_up(int id,Tree &l,Tree &r)
{
//	printf(">%d %d %d %lld %d\n",id,tre[id].l,tre[id].r,tre[id].sum,tre[id].len);
	push_up_a(id,l,r);
//	printf(">>%d %d %d %lld %d\n",id,tre[id].l,tre[id].r,tre[id].sum,tre[id].len);
	for (register int i = 0;i++ < tre[id].len_short_interval;tre[id].short_interval[i] = 0);
	tre[id].long_interval.clear();
	for (register int i = 0;i++ < sqrt(l.len);tre[id].short_interval[i] += l.short_interval[i]);
	for (register int i = 0;i++ < sqrt(r.len);tre[id].short_interval[i] += r.short_interval[i]);
	vector<int> :: iterator st1 = l.long_interval.begin(),st2 = r.long_interval.begin(),ed1 = l.long_interval.end(),ed2 = r.long_interval.end();
	while (st1 != ed1 && st2 != ed2)
		if ((*st1) < (*st2))
		{
			if ((*st1) <= tre[id].len_short_interval)
				tre[id].short_interval[(*st1)]++;
			else
				tre[id].long_interval.push_back((*st1));
			st1++;
		}
		else
			if ((*st1) > (*st2))
			{
				if ((*st2) <= tre[id].len_short_interval)
					tre[id].short_interval[(*st2)]++;
				else
					tre[id].long_interval.push_back((*st2));
				st2++;
			}
			else
			{
				if ((*st1) <= tre[id].len_short_interval)
					tre[id].short_interval[(*st1)] += 2;
				else
				{
					tre[id].long_interval.push_back((*st1));
					tre[id].long_interval.push_back((*st2));
				}
				st1++;
				st2++;
			}
	while (st1 != ed1)
	{
		if ((*st1) <= tre[id].len_short_interval)
			tre[id].short_interval[(*st1)]++;
		else
			tre[id].long_interval.push_back((*st1));
		st1++;
	}
	while (st2 != ed2)
	{
		if ((*st2) <= tre[id].len_short_interval)
			tre[id].short_interval[(*st2)]++;
		else
			tre[id].long_interval.push_back((*st2));
		st2++;
	}
	if (l.p)
	{
		if (l.len_r <= tre[id].len_short_interval)
			tre[id].short_interval[l.len_r]--;
		else
		{
			vector<int> :: iterator it = lower_bound(tre[id].long_interval.begin(),tre[id].long_interval.end(),l.len_r);
			tre[id].long_interval.erase(it);
		}
		if (r.len_l <= tre[id].len_short_interval)
			tre[id].short_interval[r.len_l]--;
		else
		{
			vector<int> :: iterator it = lower_bound(tre[id].long_interval.begin(),tre[id].long_interval.end(),r.len_l);
			tre[id].long_interval.erase(it);
		}
		if (l.len_r + r.len_l <= tre[id].len_short_interval)
			tre[id].short_interval[l.len_r + r.len_l]++;
		else
		{
			vector<int> :: iterator it = lower_bound(tre[id].long_interval.begin(),tre[id].long_interval.end(),l.len_r + r.len_l);
			tre[id].long_interval.insert(it,l.len_r + r.len_l);
		}
	}
}
inline void build(int l,int r,int id)
{
	tre[id].l = l;
	tre[id].r = r;
	tre[id].len = r - l + 1;
	tre[id].len_short_interval = sqrt(tre[id].len);
	tre[id].short_interval = new int[tre[id].len_short_interval + 1];
	tre[id].tag_a = -1;
	tre[id].tag_p = -1;
	if (l == r)
	{
		tre[id].sum = a[l] % mod;
		tre[id].sum_add = a[l] % mod;
		tre[id].sum_mul = a[l] % mod;
		tre[id].sum_l = a[l] % mod;
		tre[id].sum_r = a[l] % mod;
		tre[id].val_l = a[l] % mod;
		tre[id].val_r = a[l] % mod;
		tre[id].len_l = 1;
		tre[id].len_r = 1;
		tre[id].short_interval[1] = 1;
		tre[id].p = p[l];
		return;
	}
	int mid = l + r >> 1;
	build(l,mid,l_id(id));
	build(mid + 1,r,r_id(id));
	push_up(id,tre[l_id(id)],tre[r_id(id)]);
}
inline void push_down_p(int id,int tag_p)
{
	tre[id].p = tag_p;
	tre[id].tag_p = tag_p;
	for (register int i = 0;i++ < tre[id].len_short_interval;tre[id].short_interval[i] = 0);
	tre[id].long_interval.clear();
	if (!tag_p)
	{
		tre[id].sum = tre[id].sum_add;
		tre[id].sum_l = tre[id].val_l;
		tre[id].sum_r = tre[id].val_r;
		tre[id].len_l = 1;
		tre[id].len_r = 1;
		tre[id].short_interval[1] = tre[id].len;
	}
	else
	{
		tre[id].sum = tre[id].sum_mul;
		tre[id].sum_l = tre[id].sum_mul;
		tre[id].sum_r = tre[id].sum_mul;
		tre[id].len_l = tre[id].len;
		tre[id].len_r = tre[id].len;
		if (tre[id].len <= tre[id].len_short_interval)
			tre[id].short_interval[tre[id].len] = 1;
		else
			tre[id].long_interval.push_back(tre[id].len);
	}
}
inline void push_down_a(int id,ll tag_a)
{
	tag_a %= mod;
	tre[id].tag_a = tag_a;
	tre[id].sum_add = tag_a * tre[id].len % mod;
	tre[id].sum_mul = ksm(tag_a,tre[id].len);
	tre[id].sum_l = ksm(tag_a,tre[id].len_l);
	tre[id].sum_r = ksm(tag_a,tre[id].len_r);
	tre[id].sum = 0;
	tre[id].val_l = tag_a;
	tre[id].val_r = tag_a;
	for (register int i = 0;i++ < tre[id].len_short_interval;)
		if (tre[id].short_interval[i])
		{
			if (x != tag_a || c > i)
			{
				y = ksm(tag_a,i);
				x = tag_a;
			}
			else
				y = y * ksm(tag_a,i - c) % mod;
			tre[id].sum = (tre[id].sum + y * tre[id].short_interval[i] + mod) % mod;
			c = i;
		}
	vector<int> :: iterator st = tre[id].long_interval.begin(),ed = tre[id].long_interval.end();
	while (st != ed)
	{
		if (x != tag_a || c > (*st))
		{
			y = ksm(tag_a,(*st));
			x = tag_a;
		}
		else
			y = y * ksm(tag_a,(*st) - c) % mod;
		tre[id].sum = (tre[id].sum + y + mod) % mod;
		c = (*st++);
	}
}
inline void push_down(int id)
{
	if (tre[id].tag_p != -1)
	{
		push_down_p(l_id(id),tre[id].tag_p);
		push_down_p(r_id(id),tre[id].tag_p);
		tre[id].tag_p = -1;
	}
	if (tre[id].tag_a != -1)
	{
		push_down_a(l_id(id),tre[id].tag_a);
		push_down_a(r_id(id),tre[id].tag_a);
		tre[id].tag_a = -1;
	}
}
inline void update_p(int id,int s,int t,int k)
{
	if (tre[id].l == s && tre[id].r == t)
	{
		push_down_p(id,k);
		return;
	}
	push_down(id);
	int mid = tre[id].l + tre[id].r >> 1;
	if (mid >= t)
		update_p(l_id(id),s,t,k);
	else
		if (mid < s)
			update_p(r_id(id),s,t,k);
		else
		{
			update_p(l_id(id),s,mid,k);
			update_p(r_id(id),mid + 1,t,k);
		}
	push_up(id,tre[l_id(id)],tre[r_id(id)]);
}
inline void update_a(int id,int s,int t,ll k)
{
	if (tre[id].l == s && tre[id].r == t)
	{
		push_down_a(id,k);
		return;
	}
	push_down(id);
	int mid = tre[id].l + tre[id].r >> 1;
	if (mid >= t)
		update_a(l_id(id),s,t,k);
	else
		if (mid < s)
			update_a(r_id(id),s,t,k);
		else
		{
			update_a(l_id(id),s,mid,k);
			update_a(r_id(id),mid + 1,t,k);
		}
	push_up(id,tre[l_id(id)],tre[r_id(id)]);
}
// Concatenates two adjacent query results, l immediately before r.
// When l.p is non-zero the halves are joined by '*', so l's last run and
// r's first run fuse into one run.
inline Tre get_sum(Tre l,Tre r)
{
	const bool joined = l.p != 0;
	Tre out;
	out.p = r.p;
	out.len = l.len + r.len;
	if (joined)
	{
		// sum = l.sum + r.sum - l.sum_r - r.sum_l + l.sum_r * r.sum_l  (mod)
		ll acc = (l.sum + r.sum) % mod;
		acc = (acc - l.sum_r) % mod;
		acc = (acc - r.sum_l + mod) % mod;
		out.sum = (acc + l.sum_r * r.sum_l + mod) % mod;
	}
	else
		out.sum = (l.sum + r.sum) % mod;
	const bool left_spans = joined && l.len_l == l.len;
	out.len_l = left_spans ? l.len + r.len_l : l.len_l;
	out.sum_l = left_spans ? l.sum * r.sum_l % mod : l.sum_l;
	const bool right_spans = joined && r.len_r == r.len;
	out.len_r = right_spans ? l.len_r + r.len : r.len_r;
	out.sum_r = right_spans ? l.sum_r * r.sum % mod : r.sum_r;
	return out;
}
// Returns the summary of a[s..t], where [s, t] lies inside node id's range.
// Pending tags are pushed down before descending, so child fields are current.
inline Tre query(int id,int s,int t)
{
	if (tre[id].l == s && tre[id].r == t)
	{
		const Tree &node = tre[id];
		// Standard aggregate initialisation instead of the GNU-only compound
		// literal "(Tre){...}"; member order is p, len, len_l, len_r, sum, sum_l, sum_r.
		Tre whole = {node.p,node.len,node.len_l,node.len_r,node.sum,node.sum_l,node.sum_r};
		return whole;
	}
	push_down(id);
	const int mid = (tre[id].l + tre[id].r) >> 1;
	if (t <= mid)
		return query(l_id(id),s,t);
	if (s > mid)
		return query(r_id(id),s,t);
	return get_sum(query(l_id(id),s,mid),query(r_id(id),mid + 1,t));
}
// Reads n, m, the values a[1..n] and the operators p[1..n-1] (non-zero = '*',
// zero = '+'), builds the tree, then processes m operations:
//   1 l r k : a[l..r] = k
//   2 l r k : p[l..r] = k
//   3 l r   : print the value of a[l] p[l] ... a[r] modulo 1e9+7
int main()
{
	n = read(),m = read();
	// 'register' dropped from the loop counters: ill-formed in C++17.
	for (int i = 1;i <= n;i++)
		a[i] = readd() % mod;
	for (int i = 1;i < n;i++)
		p[i] = read();
	build(1,n,1);
	while (m--)
	{
		int op = read(),l = read(),r = read();
		if (op == 1)
			update_a(1,l,r,readd() % mod);
		else if (op == 2)
			update_p(1,l,r,read());
		else if (op == 3)
		{
			print(query(1,l,r).sum);
			putchar('\n');
		}
	}
	return 0;
}
2024/11/24 21:33
加载中...