0%

Luogu3373线段树模板2|题解

如题,已知一个数列,你需要进行下面三种操作:

1.将某区间每一个数乘上x

2.将某区间每一个数加上x

3.求出某区间每一个数的和

是一道卡了很久的题目.与更朴素的模板的差别在于乘法操作

显然需要两个懒标记tag1和tag2.tag1标记加法,tag2标记乘法

接下来考虑两个标记的嵌套关系

注意到如果加法在前:(x + a) * b + c = (x + a + c/b) * b

出现了分数.

所以应当由乘法在前

更新关系为:

void add1(llint pval){
          tag1 = (tag1 + pval) % p;
          val = (val + length() * pval) % p;
          return ;
}
void add2(llint pval){
          tag2 = (tag2 * pval) % p;
          tag1 = (tag1 * pval) % p;
          val = (val * pval) % p;
          return ;
}

整体代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <algorithm>
#include <cmath>
#define llint long long
using namespace std;
llint p;
struct node{
        node *lson,*rson;
        llint l,r;
        llint val,tag1,tag2;
        void eq (node *a,node *b,llint c,llint d,llint e,llint f,llint g){
                lson = a;rson = b;l = c;r = d;val = e;tag1 = f;tag2 = g;
                return ;
        }
        node (llint* arr,llint ll,llint rr){
                llint mid = ((rr - ll) >> 1) + ll;
                if (ll == rr){
                        this->eq(NULL,NULL,ll,rr,arr[ll],0,1);
                        return ;
                }
                lson = new node(arr,ll,mid),rson = new node(arr,mid + 1,rr);
                this->eq(lson,rson,ll,rr,lson->val + rson->val,0,1);
        }
        llint length(){
                return r - l + 1;
        }
        void maintain(){
                if(l == r) return ;
                val = (lson->val + rson->val) % p;
                return ;
        }
        //"add" means add value to a node
        void add1(llint pval){
                tag1 = (tag1 + pval) % p;
                val = (val + length() * pval) % p;
                return ;
        }
        void add2(llint pval){
                tag2 = (tag2 * pval) % p;
                tag1 = (tag1 * pval) % p;
                val = (val * pval) % p;
                return ;
        }
        //"down" means push tags to the sons
        void down(){
                if(lson){
                        lson->add2(tag2);
                        lson->add1(tag1);
                }
                if(rson){
                        rson->add2(tag2);
                        rson->add1(tag1);
                }
                tag1 = 0;tag2 = 1;
                return ;
        }
        void oper1(llint pl,llint pr,llint pval){
                if(pl > r || pr < l) return ;
                if(pl <= l && pr >= r){
                        add1(pval);
                        return ;
                }
                down();
                lson->oper1(pl,pr,pval);
                rson->oper1(pl,pr,pval);
                maintain();
                return ;
        }
        void oper2(llint pl,llint pr,llint pval){
                if(pl > r || pr < l) return ;
                if(pl <= l && pr >= r){
                        add2(pval);
                        return ;
                }
                down();
                lson->oper2(pl,pr,pval);
                rson->oper2(pl,pr,pval);
                maintain();
                return ;
        }
        llint query(llint pl,llint pr){
                if(pl > r || pr < l) return 0;
                if(pl <= l && pr >= r)
                        return val;
                down();
                return (lson->query(pl,pr) + rson->query(pl,pr)) % p;
        }
}*root;
int main (){
        llint n,m,x,y,k,al,arr[300000];
        cin >> n >> m >> p;
        for (int i = 1;i <= n;++ i)
                scanf("%lld",&arr[i]);
        root = new node(arr,1,n);
        for (int i = 1;i <= m;++ i){
                scanf("%lld",&al);
                if(al == 1){
                        scanf("%lld%lld%lld",&x,&y,&k);
                        root->oper2(x,y,k);
                }
                if(al == 2){
                        scanf("%lld%lld%lld",&x,&y,&k);
                        root->oper1(x,y,k);
                }
                if(al == 3){
                        scanf("%lld%lld",&x,&y);
                        llint ans = root->query(x,y);
                        printf("%lld\n",ans);
                }
        }
        return 0;
}

以上