裏紙

ほぼ競プロ、たまに日記

Codechef October Challenge 2017 - Shooting on the array

問題

Contest Page | CodeChef

問題概要

長さnの数列aが与えられる。xy平面を考えて、数列に対してi(1 \le n \le n)番目の要素について、(i,0)(i,a_i )を端点とする線分を引く。以下の2種類のクエリが合計でQ個与えられるのでそれを処理せよ:

  • + i X :  a_i += X
  • ? i L R : x軸に平行な光が、R-L+1個発射されると想定する。j(1 \le j \le R-L+1)番目の光は、座標 (i-0.5 , L+j-1.5)からx軸正の方向に向けて発射されるが、一度線分にぶつかると光は遮断される。このとき、光が当たっている線分の個数を答える。

2つ目のクエリについては、光はクエリごとに独立で考える → 毎回このクエリごとに光源を設置して、このクエリに答えたらその光源はなくなる。

  •  1 \le n \le 10^6
  •  1 \le a_i \le 10^9
  •  1 \le Q \le 10^5
  •  0 \le X \le 10000
  •  1 \le L \le R \le 10^9

イデア

まず、2つ目の幾何的なクエリは分かりにくいので、幾何的な要素を除外する:

i \le jであるjについて、高さ k-0.5 (L \le k \le R)の光が当たるということは、a_j \ge kであり、 i \le l \lt jについて、 a_l \lt a_jである必要がある。このことから、光がj番目の線分にあたるための線分の条件は、

  •  a_j \ge L (これ以上無いとそもそも光が当たらない)
  •  a_l \lt a_j (i \le l \lt j)
  •  max(a_l) \lt R (i \le l \lt j) (これがないとj番目の線分に光が届かない)

なので、これを満たすようなj (i \le j)の個数を数える、ということになる。

3番目の条件はR \le a_hを満たす最小のh (i \le h)を見つけ、h\lt jの部分については考えないということにすれば良いので、i \le j \le hについて考えることで、3番目の条件を除外できる。hを見つけることを考えると、maxを取るsegment tree上を二分探索することでO(logn)で求めることが出来る。

以下、1,2の条件のみを考える。Rはもう関係ないので、今segment tree上で [l , r)にパラメータLという3つ組のクエリに答えることを考えよう。この区間に対して、L = -1というパラメータが与えられた場合の答えをxとして持っているとする。

区間の要素が1つだけならこのクエリに対する答えは自明で、L以上なら1になる。

区間の要素が2つ以上ならば、2つの区間に分けて再帰的に解いていくことになる。左側、右側のそれぞれのノードに対する答えをx_{left} , x_{right}とする。左側の子ノードの区間に対して、その区間の最大値で以下のように場合分け出来る:

  • 最大値がL未満の場合: 左側の子ノードの区間で光が当たる線分は自明に0本なので、右側の区間にのみ再帰を伸ばしていけばよい。
  • 最大値がL以上の場合: 左側の子ノードの区間で光が当たる線分は1本以上あるので、左側の区間再帰を伸ばす。その時に、右側を考えると、パラメータLに依存せず答えが一定になることがわかる(左側にL以上の値があれば、そちらがよりstrictな条件になるから)。そして、それはx - x_lで得られる(x_rではないことに注意)。

つまり、どちらの場合にせよ伸ばす再帰の方向はどちらか1つなので、この再帰はsegment tree上でO(logn)でクエリに答えることが出来る。実際のクエリでは、これが、O(logn)個の頂点で発生するので、O((logn)^2)で処理できる。

また、更新クエリに関しては点更新なので、事前計算で持っておくべき値xについて、計算し直す必要のあるノードがO(logn)個あるので、O((logn)^2)で処理できる。

実装(C++)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

using pi = pair<int,int>;

struct SegTree{
    int n;
    vector<ll> dat,x;

    SegTree(int _n){
        n=1;
        while(n<_n) n*=2;

        dat=vector<ll>(2*n-1,0);
        x=vector<ll>(2*n-1,1);
    }

    int count(int k, int L)
    {
        // 葉ノード
        if(k>=n-1) return dat[k]>L;

        if(dat[2*k+1]<L) return count(2*k+2,L);

        int ret = count(2*k+1,L);
        ret += x[k] - x[2*k+1];
        return ret;
    }

    void add(int k, ll a){
        k+=n-1;
        dat[k]+=a;
        //更新
        while(k>0){
            k=(k-1)/2;
            dat[k] = max(dat[2*k+1],dat[2*k+2]);
            x[k] = x[2*k+1] + count(2*k+2,dat[2*k+1]);
        }
    }

    ll _maxquery(int a, int b, int k, int l, int r){
        if(r<=a || b<=l) return -1;

        if(a<=l && r<=b) return dat[k];

        ll vl=_maxquery(a,b,2*k+1,l,(l+r)/2);
        ll vr=_maxquery(a,b,2*k+2,(l+r)/2,r);
        return max(vl,vr);
    }
    //[a,b)
    ll maxquery(int a, int b){
        return _maxquery(a,b,0,0,n);
    }

    int _find_h(int a, int b, int k, int l, int r, ll val)
    {
        // この区間の最大値がval未満
        if(dat[k]<val) return r;
        // 区間内に値が1つ
        if(l==r-1) return l;

        // 左の子ノードの終点がa以下
        if((l+r)/2<=a) return _find_h(a,b,2*k+2,(l+r)/2,r,val);

        int ret = _find_h(a,b,2*k+1,l,(l+r)/2,val);
        if(ret<(l+r)/2) return ret;
        return _find_h(a,b,2*k+2,(l+r)/2,r,val);
    }
    //[a,b)でval以上になる最小のindex
    int find_h(int a, int b, ll val)
    {
        return _find_h(a,b,0,0,n,val);
    }

    pi _query(int a, int b, int k, int l, int r, int L)
    {
        if(r<=a || b<=l) return {0,-1};
        if(a<=l && r<=b) return {count(k,L),dat[k]};

        pi vl = _query(a,b,2*k+1,l,(l+r)/2,L);
        pi vr = _query(a,b,2*k+2,(l+r)/2,r,max(L,vl.se));

        return {vl.fi+vr.fi, max(vl.se,vr.se)};
    }

    int query(int k, int L, int R)
    {
        return _query(k,min(n,find_h(k,n,R)+1),0,0,n,L-1).fi;
    }
};

void solve()
{
    int n,Q;
    scanf(" %d %d", &n, &Q);

    SegTree st(n);
    rep(i,n)
    {
        int a;
        scanf(" %d", &a);
        st.add(i,a);
    }

    while(Q--)
    {
        char c;
        int idx;
        scanf(" %c %d", &c, &idx);
        --idx;
        if(c=='+')
        {
            int X;
            scanf(" %d", &X);
            st.add(idx,X);
        }
        else
        {
            int L,R;
            scanf(" %d %d", &L, &R);
            printf("%d\n", st.query(idx,L,R));
        }
    }
}

int main()
{
    int T;
    scanf(" %d", &T);
    while(T--) solve();
    return 0;
}