読者です 読者をやめる 読者になる 読者になる

裏紙

ほぼ競プロ、たまに日記

FHC 2017 R2 C - Fighting all the Zombies

programming FHC

問題

Fighting all the Zombies | Facebook Hacker Cup 2017 Round 2

問題概要

RPGをやっている。主人公は魔法使いである。

N体のゾンビがいる洞窟がある。i番目のゾンビの強さはiである。いま、レベリングのために、この洞窟をM回周回しようと考えている。この時、主人公は洞窟に入るたびにN体全てのゾンビを倒して洞窟から出てきて、再び洞窟に入るとN体全てのゾンビが復活している。

主人公は、N本の杖を持っている。i番目の杖の強さはiである。それぞれの杖は洞窟に入るごとに一度しか使えず、杖を使うと一体のゾンビを倒すことが出来る。

はじめ、主人公はN種類の呪文を覚えている。i番目の呪文は強さiを持っており、それはi番目の杖を使うことで発動できる。

さて、主人公はM回この洞窟にいくことになるが、i回目に洞窟に入る前に新たにS_i個の呪文を覚える。そのS_i個の呪文の強さは全てW_iであり、全て強さZ_iの杖によって発動可能な呪文である。ただし、呪文の強さと杖の強さに大きな差はなく、具体的には | W_i - Z_i | \le 1を満たす。また、全ての呪文は区別して考える。

この状況で、i回目に洞窟にいった時のN体のゾンビの倒し方が何通りあるか知りたい。この時、少なくとも1体のゾンビに対して別の呪文を使うような組み合わせは区別して数える。

i回目に洞窟に行った時のゾンビの倒し方の組み合わせがP_i通りあるとしたときの\displaystyle \sum_{i=1}^{M} P_i10^9 + 7で割った余りを答えよ。

  •  1 \le N, M \le 800000
  •  1 \le W_i , Z_i \le N
  •  1 \le S_i \le 10^9

イデア

設定が複雑で分かりにくい。1つずつまとめていく。

まず、I番目のゾンビを倒すためには強さiの呪文しか使えないということなので、初期状態では組み合わせは1通りしかない。また、杖に注目すると、新しい呪文を覚えていったとしても、i番目の杖で発動可能な呪文の強さはi-1ii+1の3種類に限られるということが分かる。

さて、強さ1の杖から順番に、どの強さの呪文を発動するかを考えていく。1の杖では、(使える呪文が1つ以上あるのなら)強さ1か2の呪文を使うことが出来る。そして、2,3,4,… と弱い方から順に呪文の強さを選んでいき、次に強さiの杖でどの呪文を発動するかということを考えると、この状況で強さi+1が埋まっていることは有り得ないが、ii+1が埋まっていることはあり得る。このことから次のようなDPを考える。

dp[i][prev][now] = i番目の杖に注目していて、強さi-1の呪文は既に使われたか(prev)、強さiの呪文は既に使われたか(now)という状況の時の組み合わせの個数

初期状態としてはdp[1][0][1]=1ということになり、最終的には求めたいものはdp[N+1][1][0]となる。ただ、このDPの計算をM回もやろうとするとO(NM)となって間に合わない。ただ、遷移に注目してみると周回している間、その1回1回の間で遷移が変化するのは1箇所だけなので、その無駄をなくしてなんとかしたいと考えられる。

ここで、DPのiとi+1の間の遷移がどうなっているのかを図で書いてみるとこのようになる(はじめはdp[0][0][1]=1から始まっているので、意味のある遷移だけを書いた)。

f:id:imulan:20170202002346p:plain

この図を見て分かる通り、dpとして式を定義したはいいものの、結局dp[*][0][1]とdp[*][1][0]しか必要ないことに気づく。ということで、前者をa_i,後者をb_iと置く。そして、杖ii-1, i, i+1のそれぞれの強さの呪文をいくつずつ使えるのかをC_{i,0} , C_{i,1}, C_{i,2}とすると、次のような漸化式が成り立つ。

a_{i+1} = b_i * C_{i,2}

b_{i+1} = a_i * C_{i,0} + b_i * C_{i,1}

これらは行列の形に直せて、a_1 = 1, b_1 = 0という初期値も合わせると、N個の行列の積によって途中の遷移が表されることになり、答えはb_{N+1}ということになる。

この形になったところで、M回の周回時に毎回この行列積を初めから計算していたのでは間に合わないので、更新されたところだけを計算し直すという方法を取る。

コンテスト本番時は平方分割によってO(\sqrt{N})での実現を図ったが、思った以上にケースごとの処理時間が長く、時間内に提出できなかった。そこで、平方分割ではなく、SegTreeによってこの更新を行っていき、答えを求める。それによって、1回ごとにO(logN)での更新が可能になる。さすがにNが大きくなるとこの2つにも差が出るのだなあという感じがする。

平方分割よりかなり速くなったけど4分くらい入力に対してかかった…まあギリギリか。ただ意外と行列をSegtreeにのせる実装が最大値のSegtreeのテンプレがあったらちょっと書き換えるだけで出来て、意外と大変じゃないのかという感想。

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

const ll mod=1e9+7;

typedef vector<ll> vl;
typedef vector<vl> mat;

mat mul(const mat &a, const mat &b)
{
    mat c(2,vl(2));
    rep(i,2)rep(j,2)
    {
        rep(k,2) (c[i][j]+=a[i][k]*b[k][j])%=mod;
    }
    return c;
}

vl mulv(const mat &a, const vl &b)
{
    vl ret(2);
    ret[0]=(a[0][0]*b[0]+a[0][1]*b[1])%mod;
    ret[1]=(a[1][0]*b[0]+a[1][1]*b[1])%mod;
    return ret;
}

struct MatSegTree{
    int n; vector<mat> dat;
    //初期化
    MatSegTree(int _n){
        n=1;
        while(n<_n) n*=2;
        dat=vector<mat>(2*n-1,mat({{0,0},{0,1}}));
    }
    //k番目(0-indexed)のx番の要素に+a
    void add(int k, int x, ll a){
        k+=n-1;

        int p,q;
        if(x==0){p=1;q=0;}
        else if(x==1){p=1;q=1;}
        else{p=0;q=1;}

        (dat[k][p][q]+=a)%=mod;

        //更新
        while(k>0){
            k=(k-1)/2;
            dat[k]=mul(dat[2*k+1],dat[2*k+2]);
        }
    }
    //内部的に投げられるクエリ
    mat _query(int a, int b, int k, int l, int r){
        if(r<=a || b<=l) return mat({{0,0},{0,1}});

        if(a<=l && r<=b) return dat[k];

        mat VL=_query(a,b,2*k+1,l,(l+r)/2);
        mat VR=_query(a,b,2*k+2,(l+r)/2,r);
        return mul(VL,VR);
    }
    //[a,b)の行列積を計算
    mat query(int a, int b){
        return _query(a,b,0,0,n);
    }
};

ll solve()
{
    int N,M;
    cin >>N >>M;

    vector<ll> w(M+1),d(M+1),z(M+1),s(M+1);
    ll A,B;

    cin >>w[1] >>A >>B;
    for(int i=2; i<=M; ++i) w[i]=((A*w[i-1]+B)%N)+1;

    cin >>d[1] >>A >>B;
    for(int i=2; i<=M; ++i) d[i]=(A*d[i-1]+B)%3;

    for(int i=1; i<=M; ++i) z[i]=max(1LL, min((ll)N,w[i]+d[i]-1));

    cin >>s[1] >>A >>B;
    for(int i=2; i<=M; ++i) s[i]=((A*s[i-1]+B)%1000000000)+1;

    MatSegTree st(N);
    ll ret=0;
    for(int i=1; i<=M; ++i)
    {
        // update
        int idx = N-w[i];

        int k;
        if(z[i]<w[i]) k=0;
        else if(z[i]==w[i]) k=1;
        else k=2;

        st.add(idx,k,s[i]);

        // calc
        mat m = st.query(0,N);

        vl res = mulv(m,vl({0,1}));
        (ret+=res[1])%=mod;
    }
    return ret;
}

int main()
{
    int T;
    cin >>T;
    rep(i,T) printf("Case #%d: %lld\n", i+1, solve());
    return 0;
}