萌新求助,刚学线段树,找错找了好久没发现,求大佬帮助
查看原帖
萌新求助,刚学线段树,找错找了好久没发现,求大佬帮助
114859
shight楼主2020/12/26 10:45
#include<bits/stdc++.h>
using namespace std;
struct tree
{
    long long l,r,val,tag,mu;
}t[1000001];
int a[100001],n,m,mod;
long long ls(long long x){return x*2;}
long long rs(long long x){return x*2+1;}
void build(int id,int l,int r)
{
    t[id].l=l,t[id].r=r,t[id].mu=1;
    if(l==r)
    {
        t[id].val=a[l]%mod;
        return;
    }
    int mid=(l+r)/2;
    build(ls(id),l,mid);
    build(rs(id),mid+1,r);
    t[id].val=(t[ls(id)].val+t[rs(id)].val)%mod;
}
void spread(int id)
{
    t[ls(id)].val+=(long long)((t[id].tag*(t[ls(id)].r-t[ls(id)].l+1)%mod)+t[id].mu*t[ls(id)].val)%mod;
    t[ls(id)].tag=(long long)(t[ls(id)].tag*t[id].mu+t[id].tag)%mod;
    t[ls(id)].mu=(long long)(t[ls(id)].mu*t[id].mu)%mod;
    t[rs(id)].val+=(long long)((t[id].tag*(t[rs(id)].r-t[rs(id)].l+1)%mod)+t[id].mu*t[rs(id)].val)%mod;
    t[rs(id)].tag=(long long)(t[rs(id)].tag*t[id].mu+t[id].tag)%mod;
    t[rs(id)].mu=(long long)(t[rs(id)].mu*t[id].mu)%mod;
    t[id].tag=0,t[id].mu=1;
    return;
}
void add(int id,int l,int r,int val)
{
    if(l<=t[id].l&&r>=t[id].r)
    {
        t[id].val=(long long)(t[id].val+val*(t[id].r-t[id].l+1))%mod;
        t[id].tag=(t[id].tag+val)%mod;
        return;
    }
    spread(id);
    int mid=(t[id].l+t[id].r)/2;
    t[id].val=t[ls(id)].val+t[rs(id)].val;
    if(l<=mid)add(ls(id),l,r,val);
    if(r>=mid+1)add(rs(id),l,r,val);
    t[id].val=(t[ls(id)].val+t[rs(id)].val)%mod;
}
void mu(int id,int l,int r,int val)
{
    if(l<=t[id].l&&r>=t[id].r)
    {
        t[id].val=(t[id].val*val)%mod;
        t[id].tag=(t[id].tag*val)%mod;
        t[id].mu=(t[id].mu*val)%mod;
        return;
    }
    spread(id);
    int mid=(t[id].l+t[id].r)/2;
    t[id].val=t[ls(id)].val+t[rs(id)].val;
    if(l<=mid)mu(ls(id),l,r,val);
    if(r>=mid+1)mu(rs(id),l,r,val);
    t[id].val=(t[ls(id)].val+t[rs(id)].val)%mod;
}
long long solve(int id,int l,int r)
{
    if(l<=t[id].l&&r>=t[id].r)return t[id].val;
    spread(id);
    int mid=(t[id].l+t[id].r)/2;
    long long ans=0;
    if(l<=mid)ans=(ans+solve(ls(id),l,r))%mod;
    if(r>=mid+1)ans=(ans+solve(rs(id),l,r))%mod;
    return ans;
}
int main()
{
    scanf("%d %d %d",&n,&m,&mod);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    build(1,1,n);
    for(int i=1,nw;i<=m;i++)
    {
        scanf("%d",&nw);
        if(nw==1)
        {
            int x,y,k;
            scanf("%d %d %d",&x,&y,&k);
            mu(1,x,y,k);
        }
        else if(nw==2)
        {
            int x,y,k;
            scanf("%d %d %d",&x,&y,&k);
            add(1,x,y,k);
        }
        else if(nw==3)
        {
            int x,y;
            scanf("%d %d",&x,&y);
            printf("%lld\n",solve(1,x,y));
        }
    }
}
2020/12/26 10:45
加载中...