裏紙

ほぼ競プロ、たまに日記

ARC 064 F - Rotated Palindromes

問題

F: Rotated Palindromes - AtCoder Regular Contest 064 | AtCoder

問題概要

長さNの数列aとしてあり得るものを全て用意する。ただし、a

  •  1 \le a_i \le K
  • aは回文(つまりaaの逆順に並べたものは一致)

という条件を満たす。このようにして得られたaに対して、「先頭の要素を末尾へ移動する」という操作を好きな回数行う。

これによって得られる最終的なaとしてあり得るものが何通りあるかをmod 10^9 + 7で答えよ。

  •  1 \le N \le 10^9
  •  1 \le K \le 10^9

イデア

例えば、数列の構成が長さ6で、abccbaのような数列だったとする。このときは、cyclic shiftにより構成される数列の種類は順に見ていくと

  • abccba (回文)
  • bccbaa
  • ccbaab
  • cbaabc (回文)
  • baabcc
  • aabccb

のようになっている。しかしながら、abababのような数列を考えるとabababとbababaの2種類しか現れないことがわかる。

このように数列内の最小周期によってその数列から生成される組み合わせの個数というのが変わってくる。そこで、この周期を全探索して、その結果を足し合わせることで全体を求めるということを考える。ただし、この全探索をする時は、「最短」周期を全探索するということを心に留めておく。

周期を全探索にするにあたって、この周期がNの約数であることは明らか。この制約ならNの約数の個数は多くても1500個程度であることがわかっているので、全探索できる。

さて、今周期をdで固定する。さて、aを全体で見て回文になっているのだから、この周期dの数列も回文になっていなければおかしい(更にその周期dの数列の中にd未満の周期が存在していてはいけない)。さて、このことから、最低周期をdに持つ数列aとして最初に設定できるものの個数をdp_d とすると、次の式が成り立つ:

 \displaystyle dp_d = K^{\lceil \frac{d}{2} \rceil} - \sum_{i \lt d , \  d\% i = 0} dp_i

回文なので、前半を自由に設定できるとして、ただしそこからそれ未満の周期が出来ている個数を引かなければいけないという式の形になっている。このしきにより、約数の個数の2乗のオーダーでこのDPを計算できる。

それができたら、個別に周期dがある回文に対して、cyclic shiftした時に何種類の数列が生成されるかを掛け合わせたものを足し合わせることで最終的な答えが求められる。そこで、dの偶奇によって生成される種類が変わってくる。まず、dが奇数ならその時は中心の数の位置によってd個全てが違うものであると区別がつくのでそのままd個になる。ただし、dが偶数なら上でも見たようにcyclic shiftの途中でもう一度回文が出ていることになり、同じものを2回カウントしている事になっている。よってこの時はd/2個がshiftによって生成されることになる。

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

const ll mod=1e9+7;

ll mod_pow(ll x, ll n)
{
    ll pw[40];
    pw[0]=x;
    for(int i=1; i<40; ++i) pw[i]=(pw[i-1]*pw[i-1])%mod;

    ll ret=1;
    rep(i,40)
    {
        if(n>>i&1) (ret*=pw[i])%=mod;
    }
    return ret;
}

vector<int> factor(int n)
{
    vector<int> ret;

    for(int i=1; i*i<=n; ++i)
    {
        if(n%i==0)
        {
            ret.pb(i);
            if(i != n/i) ret.pb(n/i);
        }
    }

    sort(all(ret));
    return ret;
}

int main()
{
    int n,k;
    scanf(" %d %d", &n, &k);

    vector<int> d = factor(n);
    int D=d.size();

    vector<ll> dp(D);
    rep(i,D)
    {
        dp[i] = mod_pow(k,(d[i]+1)/2);

        rep(j,i)if(d[i]%d[j]==0) dp[i] = (dp[i]-dp[j]+mod)%mod;
    }

    ll ans=0;
    rep(i,D)
    {
        if(d[i]%2==0) (ans+=dp[i]*d[i]/2)%=mod;
        else (ans+=dp[i]*d[i])%=mod;
    }

    printf("%lld\n", ans);
    return 0;
}