裏紙

ほぼ競プロ、たまに日記

AOJ 2022 - Princess, a Cryptanalyst

問題

Princess, a Cryptanalyst | Aizu Online Judge

問題概要

N個の文字列sが与えられる.このN個の文字列に対してSSS(Shortest Secret String)を求めたい.SSSとは,N個の文字列を全て部分文字列として含み,なおかつその長さが最小であるような文字列のことを表す.

与えられるN個の文字列に対して,SSSを求めよ.なお複数ある場合は辞書順最小のものを答えよ.

  •  1 \le N \le 10
  •  1 \le | s_i | \le 10

イデア

初めは単純に順列を試して,共通部分をうまくくっつけながら全通り試すだけでもギリギリ間に合いそうな気がしたけど,ギリギリ間に合わせるのが想定解ではないだろうなと思って,別の方針にたどり着いた.

まず,文字列に対して,全く同じ文字列が複数含まれている場合は1つだけでよくそれ以上同じものがあるのは冗長なのは明らか.また,ある文字列が別の文字列の部分列になっている場合も,それは不要である(例:watermelonとmelonという文字列があるときには,watermelonだけでその2つをカバーできるので,melonはあってもなくても変わらない).なので,そういうものも試す文字列リストから除外してよい.

こうして重複がなく,更にある文字列がほかの文字列の部分文字列になっているということもない文字列リストを作れる.あとはこれを上手い順番に並べて,共通部分があれば片方を取り除いてつなげていくということをすればSSSが作れそうという感じがする.

ここでは,それを全通り試すのではなく,bitDPによる解法を考えてみる.dp[直前に使用した文字列の番号][どの番号の文字列を既に使用したか]=長さが最も短く,辞書順最小の文字列とする.

この2番目の「どの番号の文字列を既に使用したか」というところを集合のビット表現で管理する.例えば,N=4のときに0101というのは,(小さい方のビットから見ることで)0番目と2番目の文字列を既に使用したという状態を表す.そして,全ての文字列を使用し終わると状態は必ず1111,つまり10進数で言えば15になるので,答えはdp[0][15],dp[1][15],dp[2][15],dp[3][15]のどれかということになる.

まず,1番目の文字列としてどれを選ぶかを決め,それをdpで初期化する.そして,状態の遷移は2番目のindexである状態が文字列を追加するごとに必ず増加していく方向になるので,状態が小さい方からloopを回していくことを考えよう.

現在の状態がmaskのときに,次に使おうと思っているものがi番目の文字列だとする.maskの状態を見て,i番目の文字列が既に使用されていたということが分かった場合,そのような遷移は不可能であるということがわかる.ただ,そうでないときにはこの遷移を考えることが出来る.そうすると遷移後の状態は現在の状態からi番目の文字列を使用するということなので,mask+(1<<i)になる.

そして,maskの状態に至る直前にはj番目の文字列を使っていたとする.すると,この時に得られる「長さが最も短く,辞書順最小の文字列」はまさにdp[j][mask]ということになる.ここから,共通部分をうまく考慮してこのdp[j][mask]の末尾にs_iを足していく.

これを繰り返していくことによって,最終的にdp[i][(1<<N)-1]( 0 \le i \le N-1)のうち最も小さいものを答えとすればよいことになる.時間計算量はO(N^2 * 2^N)

また,補足として,文字列を結合するときの共通部分の導出に関して,dpのループに入る前に前処理をしておくと楽にできる.具体的にはcov[i][j]=s[i],s[j]の順序で文字列を結合するときの共通部分の文字数というものを用意して,これを計算しておくと良い.

この計算方法は,単純な方法としては共通部分の長さをkに決め打ちして,s[i]の最後からk文字の部分列とs[j]の最初からk文字の部分列が一致しているかを調べれば良い.

文字列を値として持つdpを書いてみたが,経験がなかったのでうまくいくか不安だったけど,無事ACだったのでよかった.

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define mp make_pair
#define pb push_back
#define fi first
#define se second

typedef vector<int> vi;

int main()
{
    int n;
    while(cin >>n,n)
    {
        vector<string> s(n);
        rep(i,n) cin >>s[i];

        //重複要素の削除
        sort(all(s));
        s.erase(unique(all(s)),s.end());
        n=s.size();

        //s[i]がs[j]の部分列になっているかを調べる
        vector<bool> rm(n,false);
        rep(i,n)rep(j,n)
        {
            if(i==j) continue;

            int a=s[i].size(), b=s[j].size();
            if(a>b) continue;

            rep(k,b-a+1)
            {
                string tmp=s[j].substr(k,a);
                if(tmp==s[i]) rm[i]=true;
            }
        }

        vector<string> t(s);
        s.clear();
        rep(i,t.size())
        {
            if(rm[i]) --n;
            else s.pb(t[i]);
        }

        //s[i],s[j]の順に並べるときの文字被りの数
        vector<vi> cov(n,vi(n,0));
        rep(i,n)rep(j,n)
        {
            if(i==j) continue;

            int lim=min(s[i].size(),s[j].size());
            int ct=0;
            for(int k=1; k<=lim; ++k)
            {
                string a=s[i].substr(s[i].size()-k,k);
                string b=s[j].substr(0,k);
                if(a==b) ct=k;
            }
            cov[i][j]=ct;
        }

        string INF="";
        rep(i,100) INF+="z";

        //initialize
        string dp[10][1024];
        fill(dp[0],dp[10],INF);
        rep(i,n) dp[i][1<<i]=s[i];

        //現在の状態
        for(int mask=1; mask<(1<<n); ++mask)
        {
            //次に使おうと思ってるもの
            rep(i,n)
            {
                //already used
                if(mask>>i&1) continue;

                //どこからの遷移か(dp[j][mask])
                rep(j,n)
                {
                    //not used
                    if((mask>>j&1)==0) continue;

                    string nx=dp[j][mask];
                    string add=s[i].substr(cov[j][i]);
                    nx+=add;

                    if(dp[i][mask+(1<<i)].size()>nx.size())
                        dp[i][mask+(1<<i)]=nx;
                    else if(dp[i][mask+(1<<i)].size()==nx.size())
                        dp[i][mask+(1<<i)]=min(dp[i][mask+(1<<i)],nx);
                }
            }
        }

        string ans=INF;
        rep(i,n)
        {
            if(dp[i][(1<<n)-1].size()<ans.size()) ans=dp[i][(1<<n)-1];
            else if(dp[i][(1<<n)-1].size()==ans.size()) ans=min(ans,dp[i][(1<<n)-1]);
        }
        cout << ans << endl;
    }
    return 0;
}