- admin's blog
线段树
- @ 2024-1-5 12:05:32
线段树
区间问题线段树,可以实现区间/单点修改和区间/单点查询。
pushup(int x)
向上传递参数,也就是更新值的意思,有的题是求和,有的是异或有的是最大/最小。根据题目要求不同,维护的消息也不同。
void push(int x){ sum[x]=sum[x<<1|1]+sum[x<<1]; }
Build(int x,int l,int r)
一般在使用线段树之前需要对线段树进行构建。 build(int x,int l,int r) 中变量含义: x->当前区间的下标,l,r -> 当前区间的范围 [ l , r ] 的闭区间,当没有建到叶子节点的时候就会一直二分建下去。并且建完后会向上传递更新值。
void build(int x,int l,int r) //建树
{
if( l==r ) {sum[x]=a[l];return;} //如果左右区间相等了,即到了叶子节点,就把值赋上去,并且返回
int mid=l+r>>1; //将该区间二分
build( x<<1,l,mid ); //建立左区间树
build( x<<1|1,mid+1,r); //建立右边区间树
push(x); //将x位置的值更新
}
pushdown(int x,int l,int r)
下放lazy标记,lazy标记就是一个操作,告诉他下面的区间都要做的事情,但是目前我们不需要访问到子区间,所以就暂时打上一个标记,等后续需要用到这个区间的时候再下放标记更新值。
void down(int x,int l,int r) //下放懒操作函数
{
if( !lazy[x] ) return; //如果该位置没有懒标记,就直接返回,不需要更新
int mid=l+r>>1;
sum[x<<1]+=(mid-l+1)\*lazy[x]; //左孩子更新值
sum[x<<1|1]+=(r-mid)\*lazy[x]; //右孩子更新值
lazy[x<<1]+=lazy[x],lazy[x<<1|1]+=lazy[x]; //更新孩子的懒标记
lazy[x]=0; //懒标记已经用了下放
}
注意: 有的题目可能有多个懒标记,所以在更新的时候注意四则运算法则,先乘除后加减。 比如lazy1为加减标记,lazy2为乘法标记,那么更新值应该写成。
sum[lc] = ( sum[lc]*lazy2 + lazy1* (mid-l+1)*lazy2 )
Update(int x,int l,int r,int L,int R)
线段树的更新操作,其中,x为目前l,r区间所在的位置,L 和 R是所需要更新的区间,更新的时候注意下放懒标记。 当l,r是被包含在L,R区间内的时候。l,r区间就可以全部进行更新,反之就需要继续二分区间。
void update(int x,int l,int r,int L,int R,int v)
{ //l,r 为当前区间的范围 L,R为需要修改的区间的范围
//当我们需要修改的区间能覆盖当前区间的时候,就去更新当前区间的值,并结束
if(L<=l and r<=R) { sum[x]+=(r-l+1)\*v;lazy[x]+=v;return; }
// 反之,再当前区间是大于需要修改的区间的,必须再分
int mid=l+r>>1;
//分之前下放懒标记,更新小区间值
down(x,l,r);
if( L<=mid ) update(x<<1,l,mid,L,R,v);
if( R>mid ) update(x<<1|1,mid+1,r,L,R,v);
push(x); //由于小区间的值会更新,使用这里的值需要重新计算
}
getsum(int x,int l,int r,int L,int r)
与更新操作很像,但不过getsum是求和。有的题可能是求最大值和最小值注意据题分析
ll Sum(int x,int l,int r,int L,int R)
{
//求区间和,l,r为当前区间,L,R为需要查询的区间
if( L<=l and r<=R ) return sum[x]; //如果需要查询区间覆盖了当前区间,返回当前区间的值
down(x,l,r); //下放懒标记,更新小区间值
int mid=l+r>>1;
ll ans=0; //记录当前区间答案
if( L<=mid ) ans+=Sum(x<<1,l,mid,L,R); //如果左孩子有在这个查询区间的,加上左孩子的值
if( R>mid ) ans+=Sum(x<<1|1,mid+1,r,L,R); //同理
return ans;
}
main函数
读入叶子节点的数据和操作
cin>>n>>m;
int o,x,y;
ll z;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
for(int i=1;i<=m;i++)
{
cin>>o;
if( o==1 ) cin>>x>>y>>z,update(1,1,n,x,y,z);
else cin>>x>>y,cout<<Sum(1,1,n,x,y)<<endl;
}
完整代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+12;
typedef long long ll; //防止变量太大,超了int
ll sum[N<<2],lazy[N<<2],a[N<<2]; //一般线段树的大小为数据大小的4倍
int n,m; //n为数据大小,m为操作次数
void push(int x){ sum[x]=sum[x<<1|1]+sum[x<<1]; } //每次下放懒操作或者其他建树的时候需要更新值
void down(int x,int l,int r) //下放懒操作函数
{
if( !lazy[x] ) return; //如果该位置没有懒标记,就直接返回,不需要更新
int mid=l+r>>1;
sum[x<<1]+=(mid-l+1)\*lazy[x]; //左孩子更新值
sum[x<<1|1]+=(r-mid)\*lazy[x]; //右孩子更新值
lazy[x<<1]+=lazy[x],lazy[x<<1|1]+=lazy[x]; //更新孩子的懒标记
lazy[x]=0; //懒标记已经用了下放
}
void build(int x,int l,int r) //建树
{
if( l==r ) {sum[x]=a[l];return;} //如果左右区间相等了,即到了叶子节点,就把值赋上去,并且返回
int mid=l+r>>1; //将该区间二分
build( x<<1,l,mid ); //建立左区间树
build( x<<1|1,mid+1,r); //建立右边区间树
push(x); //将x位置的值更新
}
void update(int x,int l,int r,int L,int R,int v)
{ //l,r 为当前区间的范围 L,R为需要修改的区间的范围
//当我们需要修改的区间能覆盖当前区间的时候,就去更新当前区间的值,并结束
if(L<=l and r<=R) { sum[x]+=(r-l+1)\*v;lazy[x]+=v;return; }
// 反之,再当前区间是大于需要修改的区间的,必须再分
int mid=l+r>>1;
//分之前下放懒标记,更新小区间值
down(x,l,r);
if( L<=mid ) update(x<<1,l,mid,L,R,v);
if( R>mid ) update(x<<1|1,mid+1,r,L,R,v);
push(x); //由于小区间的值会更新,使用这里的值需要重新计算
}
ll Sum(int x,int l,int r,int L,int R)
{
//求区间和,l,r为当前区间,L,R为需要查询的区间
if( L<=l and r<=R ) return sum[x]; //如果需要查询区间覆盖了当前区间,返回当前区间的值
down(x,l,r); //下放懒标记,更新小区间值
int mid=l+r>>1;
ll ans=0; //记录当前区间答案
if( L<=mid ) ans+=Sum(x<<1,l,mid,L,R); //如果左孩子有在这个查询区间的,加上左孩子的值
if( R>mid ) ans+=Sum(x<<1|1,mid+1,r,L,R); //同理
return ans;
}
int main(){
cin>>n>>m;
int o,x,y;
ll z;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
for(int i=1;i<=m;i++)
{
cin>>o;
if( o==1 ) cin>>x>>y>>z,update(1,1,n,x,y,z);
else cin>>x>>y,cout<<Sum(1,1,n,x,y)<<endl;
}
}
有乘积操作的完整代码
#include<bits/stdc++.h>
#define lc x<<1
#define rc x<<1|1
#define int long long
using namespace std;
const int N=1e5+12;
struct node{
long long sum;
long long lazy1,lazy2;
}tree[N<<2];
int a[N],n,m,q,op,u,v,k,mod;
void pushup(int x)
{
tree[x].sum=( tree[ lc ].sum + tree[ rc ].sum )%mod;
}
void pushdown(int x, int l, int r) {
int mid = l + (r - l) / 2;
// 应用到左子节点
tree[lc].lazy1 = (tree[lc].lazy1 * tree[x].lazy2 + tree[x].lazy1) % mod;
tree[lc].lazy2 = (tree[lc].lazy2 * tree[x].lazy2) % mod;
tree[lc].sum = (tree[lc].sum * tree[x].lazy2 + (mid - l + 1) * tree[x].lazy1) % mod;
// 应用到右子节点
tree[rc].lazy1 = (tree[rc].lazy1 * tree[x].lazy2 + tree[x].lazy1) % mod;
tree[rc].lazy2 = (tree[rc].lazy2 * tree[x].lazy2) % mod;
tree[rc].sum = (tree[rc].sum * tree[x].lazy2 + (r - mid) * tree[x].lazy1) % mod;
// 重置当前节点的懒惰标记
tree[x].lazy1 = 0;
tree[x].lazy2 = 1;
}
void build(int x,int l,int r)
{
tree[x].lazy1=0;
tree[x].lazy2=1;
if( l==r )
{
tree[x].sum=a[l];
return ;
}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
pushup(x);
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
}
void update(int x,int l,int r,int L,int R,int op,int v)
{
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
if( L<=l && r<=R )
{
if( op==1 )
{
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
tree[x].sum*=v,tree[x].lazy2*=v;
tree[x].lazy1*=v;
}
else{
tree[x].sum+=v*(r-l+1),tree[x].lazy1+=v;
}
tree[x].sum=tree[x].sum%mod;
tree[x].lazy2%=mod;
tree[x].lazy1%=mod;
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
return;
}
pushdown(x,l,r);
int mid=l+r>>1;
if( L<=mid )
{
update(lc,l,mid,L,R,op,v);
}
if( R>mid )
{
update(rc,mid+1,r,L,R,op,v);
}
pushup(x);
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
}
long long getsum(int x,int l,int r,int L,int R)
{
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
if( L<=l && r<=R )
{
return tree[x].sum;
}
pushdown(x,l,r);
long long ans=0;
int mid=l+r>>1;
if( L<=mid )
ans+=getsum(lc,l,mid,L,R);
if( R >mid )
ans+=getsum(rc,mid+1,r,L,R);
ans%=mod;
//cout<<"x: "<<x<<" l: "<<l<<" r: "<<r<<" sum: "<<tree[x].sum<<endl;
return ans;
}
int read()
{
char c;
int ans=0;
c=getchar();
while( c<'0' || c>'9' ) c=getchar();
while( c>='0' && c<='9' )
{
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
main(){
n=read();
m=read();
mod=read();
for(int i=1;i<=n;i++) a[i]=read();
build(1,1,n);
while(m--)
{
op=read();
if( op==1 )
{
u=read();
v=read();
k=read();
update(1,1,n,u,v,op,k);
}
else if( op==2 )
{
u=read();
v=read();
k=read();
update(1,1,n,u,v,op,k);
}
else{
u=read();
v=read();
printf("%lld\n",getsum(1,1,n,u,v));
}
}
}