裏紙

ほぼ競プロ、たまに日記

CS Academy #20 - Palindromic Concatenation

問題

Archive

問題概要

文字列がn個与えられる。i番目の文字列をs_iと表す時、s_is_jの連結(この順番に並べて結合した文字列)が回文になるような組(i,j)(i \neq j)の個数を求めよ。

  • 1 \le n \le 10^5
  • 1 \le \sum_{i=1}^{n} | s_i | \le 10^5
  • s_iは英小文字のみで構成される
  • (i,j)(j,i)は区別される

イデア

s_i + s_jが回文になる時、どのような状況になれば条件を満たすかを考えてみる。この条件は文字列の長さに依存するので、以下のように3つに条件を分けて考える:

  1.  | s_i | = | s_j |
  2.  | s_i | \lt | s_j |
  3.  | s_i | \gt | s_j |

1のときに満たすべき条件はシンプルで、s_ireverseしたらs_jに一致することである。

次に、2のときに満たすべき条件を考えてみる。s_iの方が短いので、回文の中心はs_jの途中の位置に来ることになる。よって、それぞれの文字列の長さをI , Jとおけば

  • s_ireverseした文字列とs_jの末尾からI文字が一致
  • s_jの先頭からJ-I文字が回文になっている

という2つの条件を同時に満たす必要があることが分かる。

3についても対称に考えれば、

  • s_jreverseした文字列とs_iの先頭からJ文字が一致
  • s_iの末尾からI-J文字が回文になっている

という2つの条件を同時に満たす必要があることが分かる。

nの制約から、それぞれのペアをとってきて比較するということは間に合わなさそうなので、別の方針を考える必要がある。 そこで、ハッシュ値を利用して各s_iとペアを作ることが出来る文字列の個数を探す。

条件1に関しては、各文字列をreverseしたときのハッシュ値をカウントしておけば簡単に計算できる。以下では2,3について考える。

まず、n個の文字列は長い方から順に処理していくことにする。そして、mapにハッシュ値をカウントしていく。そのときに、先頭からx文字が回文になっていれば残りの末尾の文字のハッシュ値を+1し、末尾からx文字が回文になっていれば残りの先頭の文字のハッシュ値を+1とする。

この、「回文になっているかどうか」という判定は、各文字列に対して元の文字列と、reverseした後の文字列でそれぞれローリングハッシュを構築し、そのハッシュ値の比較をすることでO(1)で判定することが可能になる。

各文字列s_iに対して、ペアを作ることが出来る文字列の個数はこのs_ireverseしたもののハッシュ値のカウントに一致する。

実装(C++)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

struct RollingHash{
    static const int MD = 3;
    static const vector<ll> hash_base, hash_mod;

    int n;
    vector<ll> hs[MD], pw[MD];

    RollingHash(){}
    RollingHash(const string &s){
        n = s.size();
        rep(i,MD){
            hs[i].assign(n+1,0);
            pw[i].assign(n+1,0);
            hs[i][0] = 0;
            pw[i][0] = 1;
            rep(j,n){
                pw[i][j+1] = pw[i][j]*hash_base[i] % hash_mod[i];
                hs[i][j+1] = (hs[i][j]*hash_base[i]+s[j]) % hash_mod[i];
            }
        }
    }

    // 1-index
    ll hash_value(int l, int r, int i){
        return ((hs[i][r] - hs[i][l]*pw[i][r-l])%hash_mod[i]+hash_mod[i])%hash_mod[i];
    }

    bool match(int l1, int r1, int l2, int r2){
        bool ret = true;
        rep(i,MD) ret &= (hash_value(l1-1,r1,i) == hash_value(l2-1,r2,i));
        return ret;
    }

    vector<ll> calc(int l, int r){
        vector<ll> ret(MD);
        rep(i,MD) ret[i]=hash_value(l-1,r,i);
        return ret;
    }
};
const vector<ll> RollingHash::hash_base{1009,1021,1013};
const vector<ll> RollingHash::hash_mod{1000000009,1000000007,1000000021};

const int N=100000;
vector<string> s[N+1];

int main()
{
    cin.tie(0);ios::sync_with_stdio(false);

    int n;
    cin >>n;

    rep(i,n)
    {
        string tmp;
        cin >>tmp;
        s[tmp.size()].pb(tmp);
    }

    map<vector<ll>,ll> ct;
    ll ans = 0;

    for(int L=N; L>0; --L)
    {
        int SZ=s[L].size();
        if(!SZ) continue;

        vector<string> t(SZ);
        vector<RollingHash> hh(SZ),th(SZ);
        rep(i,SZ)
        {
            t[i]=s[L][i];
            reverse(all(t[i]));

            hh[i] = RollingHash(s[L][i]);
            th[i] = RollingHash(t[i]);
        }

        // 長さが同じもの
        map<vector<ll>,int> same;
        rep(i,SZ) ++same[th[i].calc(1,L)];
        rep(i,SZ) ans += same[hh[i].calc(1,L)]-(s[L][i]==t[i]);

        // 長さが違うもの
        rep(i,SZ) ans += ct[th[i].calc(1,L)];

        // ハッシュのカウントを更新
        rep(i,SZ)for(int pl=1; pl<L; ++pl)
        {
            // s[L][i]の先頭からpl文字が回文になっているか?
            if(hh[i].calc(1,pl) == th[i].calc(L-pl+1,L)) ++ct[hh[i].calc(pl+1,L)];
            // s[L][i]の末尾からpl文字が回文になっているか?
            if(hh[i].calc(L-pl+1,L) == th[i].calc(1,pl)) ++ct[hh[i].calc(1,L-pl)];
        }
    }

    cout << ans << endl;
    return 0;
}