如题,已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
是一道卡了很久的题目.与更朴素的模板的差别在于乘法操作
显然需要两个懒标记tag1和tag2.tag1
标记加法,tag2
标记乘法
接下来考虑两个标记的嵌套关系
注意到如果加法在前:(x + a) * b + c = (x + a + c/b) * b
出现了分数.
所以应当由乘法在前
更新关系为:
void add1(llint pval){
tag1 = (tag1 + pval) % p;
val = (val + length() * pval) % p;
return ;
}
void add2(llint pval){
tag2 = (tag2 * pval) % p;
tag1 = (tag1 * pval) % p;
val = (val * pval) % p;
return ;
}
整体代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <algorithm>
#include <cmath>
#define llint long long
using namespace std;
llint p;
struct node{
node *lson,*rson;
llint l,r;
llint val,tag1,tag2;
void eq (node *a,node *b,llint c,llint d,llint e,llint f,llint g){
lson = a;rson = b;l = c;r = d;val = e;tag1 = f;tag2 = g;
return ;
}
node (llint* arr,llint ll,llint rr){
llint mid = ((rr - ll) >> 1) + ll;
if (ll == rr){
this->eq(NULL,NULL,ll,rr,arr[ll],0,1);
return ;
}
lson = new node(arr,ll,mid),rson = new node(arr,mid + 1,rr);
this->eq(lson,rson,ll,rr,lson->val + rson->val,0,1);
}
llint length(){
return r - l + 1;
}
void maintain(){
if(l == r) return ;
val = (lson->val + rson->val) % p;
return ;
}
//"add" means add value to a node
void add1(llint pval){
tag1 = (tag1 + pval) % p;
val = (val + length() * pval) % p;
return ;
}
void add2(llint pval){
tag2 = (tag2 * pval) % p;
tag1 = (tag1 * pval) % p;
val = (val * pval) % p;
return ;
}
//"down" means push tags to the sons
void down(){
if(lson){
lson->add2(tag2);
lson->add1(tag1);
}
if(rson){
rson->add2(tag2);
rson->add1(tag1);
}
tag1 = 0;tag2 = 1;
return ;
}
void oper1(llint pl,llint pr,llint pval){
if(pl > r || pr < l) return ;
if(pl <= l && pr >= r){
add1(pval);
return ;
}
down();
lson->oper1(pl,pr,pval);
rson->oper1(pl,pr,pval);
maintain();
return ;
}
void oper2(llint pl,llint pr,llint pval){
if(pl > r || pr < l) return ;
if(pl <= l && pr >= r){
add2(pval);
return ;
}
down();
lson->oper2(pl,pr,pval);
rson->oper2(pl,pr,pval);
maintain();
return ;
}
llint query(llint pl,llint pr){
if(pl > r || pr < l) return 0;
if(pl <= l && pr >= r)
return val;
down();
return (lson->query(pl,pr) + rson->query(pl,pr)) % p;
}
}*root;
int main (){
llint n,m,x,y,k,al,arr[300000];
cin >> n >> m >> p;
for (int i = 1;i <= n;++ i)
scanf("%lld",&arr[i]);
root = new node(arr,1,n);
for (int i = 1;i <= m;++ i){
scanf("%lld",&al);
if(al == 1){
scanf("%lld%lld%lld",&x,&y,&k);
root->oper2(x,y,k);
}
if(al == 2){
scanf("%lld%lld%lld",&x,&y,&k);
root->oper1(x,y,k);
}
if(al == 3){
scanf("%lld%lld",&x,&y);
llint ans = root->query(x,y);
printf("%lld\n",ans);
}
}
return 0;
}
以上