裏紙

ほぼ競プロ、たまに日記

ARC 058 E - 和風いろはちゃん / Iroha and Haiku

問題

E: 和風いろはちゃん / Iroha and Haiku - AtCoder Regular Contest 058 | AtCoder

問題概要

長さNの数列aを考える.数列の各要素a_iについて,値は1以上10以下である.数列aXYZを含むものは何通り有るか答えよ.

ただし,XYZを含むとは:

  •  a_x + a_{x+1} + ... a_{y-1} = X
  •  a_y + a_{y+1} + ... a_{z-1} = Y
  •  a_z + a_{z+1} + ... a_{w-1} = Z

を満たす0 \le x \lt y \lt z \lt w \le Nが存在すると定義される.

  •  3 \le N \le 40
  •  1 \le X \le 5
  •  1 \le Y \le 7
  •  1 \le Z \le 5

イデア

解説を読みながら自分なりにまとめてみた.

まず初めに,コンテスト中に悩んだがどうしようもなくなってしまった方針について触れておくと,x,y,z,wの位置を全探索して,それに対応する組み合わせを足していくというものである.xより左とwより右は自由に選び,各区間では重複組み合わせの考え方を利用すれば全ての組み合わせを列挙出来ると考えたが,この数え上げ方だと重複が発生してしまう.xより左を自由に選んだ時にその区間にもXYZを含むような数列ができてしまい,同じ数列を2回数え上げていることになってしまうからである.これは右も同様であり,複数の区間XYZを含むような数列は存在しうるのでこの方針は非常にややこしくなってしまい,手がつけられなくなってしまった.

そこで,以下ではXYZを「含まない」数列の種類を数え上げることにしよう.数列の先頭から順番に見ていき,その位置に1~10の値をおいた時にXYZを「含まない」ようにできるか判定しながら再帰を進めていくようなイメージになる.

さて,このようなXYZを「含まない」という判定をするためにはどうすればいいか,となるが制約を見ると X+Y+Z=17とあるので,直前の16個の値が全てが1だった時にこの16個の値を保存していないと次の数列の値を決定することはできない.しかし,16個の値を常に保存するのはメモリ的にも時間的にも間に合わない.考えてみると,常に直前の値16個を保存しておくことが必要になるわけではない.保存しておくべきものは「直前に現れた区間の数値の合計が16以下になる位置」までで十分なことがわかる.これをどのように保存すればいいだろうか.

それは,次のように符号化すると良い.1="1", 2="10", 3="100", ...という符号化によって,それらを結合することで直前に現れた数字を表現する.例えば,1,2,3なら"110100"という文字列によって表される.すると,文字列の長さがそのまま数値の合計になって現れるので,「直前に現れた区間の数値の合計がX+Y+Z以下になる位置」までの数字の状態が分かることになる.あとは,これによって数え上げて組み合わせの数を全体から引いて答えとすればよい.

また,各状態において次の数1~10が来ても問題ないかを事前計算しておくことで,高速に動作する(参考:事前計算していないsubmission事前計算をしたsubmission(下の実装と同じ)).

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define mp make_pair
#define pb push_back
#define fi first
#define se second

const ll mod=1e9+7;

int N,X,Y,Z;
int S;

bool nx[1<<16][11];
ll dp[41][1<<16];

ll dfs(int now, int state)
{
    if(dp[now][state]>=0) return dp[now][state];
    if(now==N) return 1;

    ll ret=0;
    //次に選ぶ数
    for(int i=1; i<=10; ++i)
    {
        if(nx[state][i])
        {
            ret+=dfs(now+1, ((state<<i)+(1<<(i-1)))%(1<<S));
            ret%=mod;
        }
    }

    return dp[now][state]=ret;
}

int main()
{
    cin >>N >>X >>Y >>Z;
    S=X+Y+Z-1;

    //状態maskの時に,次の数iが来ていいかチェックしておく
    rep(mask,1<<S)
    {
        //直前の状態を復元
        vector<int> b;
        int st=0;
        rep(i,S)
        {
            if(mask>>i&1)
            {
                b.pb(i-st+1);
                st=i+1;
            }
        }

        int B=b.size();
        //jを次に選んだ時にXYZを含まないかチェック
        for(int i=1; i<=10; ++i)
        {
            bool ok=true;
            //直前から順番に見ていき貪欲に足していく
            int x=0,y=0,z=i;
            int bidx=0;
            while(bidx<B && z<Z) z+=b[bidx++];
            while(bidx<B && y<Y) y+=b[bidx++];
            while(bidx<B && x<X) x+=b[bidx++];
            if(x==X && y==Y && z==Z) ok=false;

            nx[mask][i]=ok;
        }
    }

    //全体は10^N通り
    ll p10=1;
    rep(i,N) p10=(p10*10)%mod;

    memset(dp,-1,sizeof(dp));
    //XYZを含まない数列の個数
    ll r=dfs(0,0);
    //全体から引く
    ll ans=(p10-r+mod)%mod;
    cout << ans << endl;
    return 0;
}