裏紙

ほぼ競プロ、たまに日記

SPOJ METEORS - Meteors

問題

SPOJ.com - Problem METEORS

問題概要

ある惑星の軌道上には隕石が降り注いでくる。その軌道をM個の領域に分割し、1,2, ..., Mと番号をふる。このとき、番号は連続的に振られ、M番目の領域と1番目の領域は繋がっている。そのM個の領域に分けた領域に対してそれぞれ1つずつ拠点が設置されており、その拠点は国o_iのものである。

それぞれの拠点では、この隕石を収集することができる。各国iは隕石をp_i個収集したいと思っている。今、隕石の落下予測がK周期先のぶんまで予測されている。その予想はそれぞれl_i , r_i , a_i3値で特徴付けられており、

  • l_i \le r_iのとき、 l_i ,l_i + 1, ... , r_i -1 , r_iの範囲にa_i個ずつ
  • l_i \gt r_iのとき、 l_i ,l_i + 1, ... ,m, 1, 2, ... , r_i -1 , r_iの範囲にa_i個ずつ

隕石が降ってくることを表す。

このとき、各国iが目標の隕石の個数p_i個を収集するまでにかかる周期を求めよ。K周期後にも回収しきれない場合にはNIEと出力せよ。

  •  1 \le M \le 300000
  •  1 \le N \le 300000
  •  1 \le o_i \le N
  •  1 \le  p_i \le 10^9
  •  1 \le K \le 300000
  •  1 \le l_i , r_i \le M
  •  1 \le a_i \le 10^9

イデア

各国ごとに、どのタイミングでできるのかを知りたいとなったときに、二分探索によって探すのが良さそう。そこで、二分探索でx周期番目までを見る時はクエリを最初から順番に見ていって区間に足していくということをすれば良い。

この区間に足していくという処理も愚直にやると時間がかかってしまうので、今回は自分がライブラリとして持っていたSegment Treeを少し改造したものを作って、区間に対してaddをするクエリと、ある1つの位置の値をgetするクエリのどちらもO(logM)で実現出来るものを用意して処理することにした(動作については後述)。

これによって、クエリの処理はO(KlogM)、その位置が条件を満たすかどうかの判定をO(NlogM)により行えるようになる。そして、これを各国に対して行うので、全体として計算量はO(N(K+N)logMlogK)となって、明らかに間に合わない。

そこで、各国に対して二分探索するのではなく、このN個の国を同時に処理していくParallel Binary Searchをする。結局、二分探索の部分でSegment Treeを構成しているが、これを各国ごとに毎回初めから構築しているのがとてもムダな部分になっていて、時間もかかってしまっている部分になる。そこをまとめていこうという考え方である。

今回作成したSegment Treeの挙動

今回実現したかったものは、区間に対するaddと1つの位置の値を取り出すgetがどちらもO(logM)でできるようなsegtreeである。

まず、addクエリについて考えていくと、次の図のような挙動にした。範囲にすっぽり収まった時点でそこにaddして止めておくということをする。

  • [3,5]に+3

f:id:imulan:20161017170427p:plain

  • [4,8]に+5

f:id:imulan:20161017170631p:plain

そして、このような状態でgetしたい時には葉からスタートして真上に辿って足し上げることでその位置の値が何か、つまり今までに回収できた隕石の個数がわかるようになる。

さて、これによってsegtreeを構成する必要のある回数はO(logK)回となる。その各回について、N個の国のクエリを平行に分割してsegtreeを構成しつつクエリを処理する。これによって全体の計算量はO((K+N)logMlogK)となり、間に合う。

また、クエリを処理している時、各国が回収できる隕石の個数がlong longの範囲を超えうるので、目標に到達した時点即数え上げをやめないといけないことに注意。

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

struct SumSegTree{
    int n; vector<ll> dat;
    //初期化
    SumSegTree(int _n){
        n=1;
        while(n<_n) n*=2;
        dat=vector<ll>(2*n-1,0);
    }
    ll get(int idx){
        idx+=n-1;
        ll ret=dat[idx];
        while(idx>0){
            idx=(idx-1)/2;
            ret+=dat[idx];
        }
        return ret;
    }
    //内部的に投げられるクエリ
    void _query(int a, int b, ll v, int k, int l, int r){
        if(r<=a || b<=l) return;

        if(a<=l && r<=b) dat[k]+=v;
        else{
            _query(a,b,v,2*k+1,l,(l+r)/2);
            _query(a,b,v,2*k+2,(l+r)/2,r);
        }
    }
    //[a,b)に+v
    void add(int a, int b, ll v){
        _query(a,b,v,0,0,n);
    }
};

const int N=300000;

int n,m,k;
// stations that i-th state has
vector<int> state[N];
int target[N];
int l[N],r[N];
ll a[N];

int L[N],R[N];

int main()
{
    //input
    scanf(" %d %d", &n, &m);

    rep(i,m)
    {
        int o;
        scanf(" %d", &o);
        --o;
        state[o].pb(i);
    }
    rep(i,n) scanf(" %d", &target[i]);
    scanf(" %d", &k);
    rep(i,k)
    {
        scanf(" %d %d %lld", &l[i], &r[i], &a[i]);
        --l[i];
        --r[i];
    }

    // initialize
    rep(i,n)
    {
        L[i]=-1;
        R[i]=k;
    }

    rep(T,20)
    {
        vector<int> q[N];
        rep(i,n) q[(L[i]+R[i])/2].pb(i);

        SumSegTree st(m+1);
        rep(i,k)
        {
            if(l[i]<=r[i]) st.add(l[i],r[i]+1,a[i]);
            else
            {
                st.add(l[i],m,a[i]);
                st.add(0,r[i]+1,a[i]);
            }

            rep(j,q[i].size())
            {
                int s=q[i][j];

                ll tmp=0;
                rep(x,state[s].size())
                {
                    tmp+=st.get(state[s][x]);
                    if(target[s]<=tmp) break;
                }

                if(target[s]<=tmp) R[s]=i;
                else L[s]=i;
            }
        }
    }

    rep(i,n)
    {
        if(R[i]<k) printf("%d\n", R[i]+1);
        else printf("NIE\n");
    }

    return 0;
}