読者です 読者をやめる 読者になる 読者になる

裏紙

ほぼ競プロ、たまに日記

AGC 003 D - Anticube

問題

D: Anticube - AtCoder Grand Contest 003 | AtCoder

問題概要

素数nの重複を許す整数の集合sが与えられる。この集合の部分集合の中で、その部分集合のどの2つの要素の積も立方数にならない部分集合の最大のサイズを求めよ。

  •  1 \le n \le 10^5
  •  1 \le s_i \le 10^{10}

イデア

まず、2数をかけて立方数にならないようにしたいと考えた時、素因数分解して同じ素因数を3個以上含むとき、立方数になるかどうかという観点ではそれは考慮する必要がないのでその素因数の3乗で割ってしまってよい(例えば、2^4 = 162と同一視してしまって問題ない)。ここではそれを"標準系"と呼ぶことにし、整数Nの標準系をNorm(N)と表す。例えば、144 = 2^4 * 3^2なので、Norm(144) = 2* 3^2 = 18となる。そして、整数tに対してtとの積をとると立方数になるような整数の標準系をPair(t)と定める。これはtに対して一意に決まることが分かる。

さて、集合の各要素について、Norm( s_i )Pair( s_i )が求まれば、それぞれのNorm(t)に対して、その集合内でNorm(t)Pair(t)の個数の多い方を選んで足しあわせていくのが最も多い部分集合のサイズとすることができるようになるわけである。ただし、Norm(t)=1の時に関してはPair(t)=1となってしまうので、それ以外について多い方を選び、Norm(t)=1なる要素があるときには答えを+1すればよいということになる。

次に、各要素についてNorm( s_i )Pair( s_i )を求めることを考えたい。s_iの上限が10^{10}であるから、 \sqrt[3]{10^{10}}までの素数を列挙して、その素数の3乗で割れるかどうか全て試すということができる。ここでNorm( s_i )が求まるが、この時、後の処理のために \sqrt[3]{ s_i }以下のs_iの素因数も知ることができるのでそれを同時に計算しておく。そして、その素因数に関して2つ余った時はPair( s_i )に対して1つ、1つ余った時はPair( s_i )に対して2つ掛けておく。

そして、Norm( s_i )が求まったところで次にPair( s_i )を完成させる。ここで残っていて考慮すべきなのは \sqrt[3]{ s_i }より大きいs_iの素因数になる。そこでさっき求めた \sqrt[3]{ s_i }以下のs_iの素因数全てで割れるだけs_iを割った時その値[は

のどれか3パターンになる。これは、 \sqrt[3]{ s_i }以下のs_iの素因数で割ってしまっているので、それ以上大きい素因数の中からs_iを構成するには多くても2つの素因数を持つことしかありえないからである。そして、この時Pair( s_i )に対して掛けるべき値は「素数の2乗」の場合はこの値の平方根で、それ以外はその値の2乗になる。この値をオーバーフローに注意しながらかけていけばPair( s_i )が完成する。オーバーフローすることが分かったらもうその値は問題で与えられる値の範囲外なので、何にも影響しないので適当な値を設定しておけば良い(コード内ではINFと設定)。素数の2乗になっているかどうかを確認するには、自分の実装では予め\sqrt{ 10^{10} }以下の素数の2乗を入れたsetを用意してそれと照らしあわせて確認する方法を取った。

Norm( s_i )Pair( s_i )が求まったので、最後に答えを求めていく。mapなどを使って、Norm(t)の各値の個数とNorm(t)に対するPair(Norm(t))の対応を保存しておいて、多い方を選択していくようにする。Norm(t) = 1があるときはそれは除外して答えに1足しておくことを忘れずに。

3乗区切りで考えていけば良さそうなところまでは気づいていたけど、この細部の実装は時間内に思いつける気がしないし、解説読みながらでもかなり時間食った...

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

const int N=100000;
bool prime[N];
vector<int> p;
int P;

const ll INF=1234567890123LL;

int main()
{
    fill(prime,prime+N,true);
    prime[0]=prime[1]=false;
    for(int i=2; i<N; ++i)
    {
        if(prime[i]) for(int j=2; i*j<N; ++j) prime[i*j]=false;
    }
    rep(i,N) if(prime[i]) p.pb(i);
    P=p.size();

    //素数の2乗
    set<ll> sq;
    rep(i,P) sq.insert((ll)p[i]*p[i]);

    int n;
    scanf(" %d", &n);
    vector<ll> s(n);
    rep(i,n) scanf(" %lld", &s[i]);

    vector<ll> Norm(s), Pair(n,1);
    //s[i]^(1/3)以下のs[i]の素因数
    vector<int> f[100000];
    //Norm(s[i])の計算
    rep(i,n)
    {
        rep(j,P)
        {
            ll cube = (ll)p[j]*p[j]*p[j];
            if(cube>s[i]) break;
            while(Norm[i]%cube==0) Norm[i]/=cube;

            if(s[i]%p[j]==0)
            {
                f[i].pb(p[j]);
                if(Norm[i]%(p[j]*p[j])==0) Pair[i]*=p[j];
                else if(Norm[i]%p[j]==0) Pair[i]*=p[j]*p[j];
            }
        }
    }

    //Pair(s[i])の導出
    rep(i,n)
    {
        ll t=s[i];
        for(const auto &d:f[i]) while(t%d==0) t/=d;

        //素数の2乗
        if(sq.find(t)!=sq.end())
        {
            ll sqt=sqrt(t);

            //オーバーフローがないかチェック
            if(Pair[i] > LLONG_MAX/sqt) Pair[i]=INF;
            else  Pair[i]*=sqt;
        }
        else
        {
            if(t<N)
            {
                if(Pair[i] <= LLONG_MAX/t*t) Pair[i]*=t*t;
            }
            else
            {
                //大きすぎて収まらない
                if(t!=1) Pair[i]=INF;
            }
        }

        if(Pair[i]>10000000000LL) Pair[i]=INF;
    }

    int one=0;
    //normの値,個数
    map<ll,int> Norm_count;
    //norm->pair
    map<ll,ll> Norm_to_Pair;
    rep(i,n)
    {
        if(Norm[i]==1) one=1;
        else
        {
            if(Norm_count.find(Norm[i]) == Norm_count.end())
            {
                Norm_count[Norm[i]]=1;
                Norm_to_Pair[Norm[i]]=Pair[i];
            }
            else ++Norm_count[Norm[i]];
        }
    }

    //答えの導出
    int ans=one;
    set<ll> Not_use;

    for(const auto &x:Norm_count)
    {
        ll norm=x.fi;
        int ct=x.se;

        if(Not_use.find(norm) != Not_use.end()) continue;

        ll pair=Norm_to_Pair[norm];
        if(pair == INF) ans+=ct;
        else
        {
            if(Norm_count.find(pair) != Norm_count.end())
            {
                if(ct>=Norm_count[pair])
                {
                    ans+=ct;
                    Not_use.insert(pair);
                }
            }
            else ans+=ct;
        }
    }

    printf("%d\n", ans);
    return 0;
}