裏紙

ほぼ競プロ、たまに日記

CODE FESTIVAL 2016 Final E - Cookies

問題

E: Cookies - CODE FESTIVAL 2016 Final (Parallel) | AtCoder

問題概要

はじめ、1秒に1枚のクッキーを焼くことが出来る。そして、クッキーを全て食べつくすという行動を取ることが出来る。枚数に関わらず、これにはA秒かかる(食べている間はクッキーを焼くことは出来ない)。しかし、直前にクッキーがx枚溜まっていたら、食べた後は1秒にx枚のクッキーを焼くことができるようになる。

N枚のクッキーを用意したいと思う時、それに必要な時間の最小値を求めよ。

  •  1 \le N \le 10^{12}
  •  0 \le A \le 10^{12}
  • Aは整数

イデア

まず、クッキーを食べる間隔についてだが、クッキーを食べてから次にクッキーを食べるまでに最低でも2秒はクッキーを焼かないと最適にはなりえない。なぜなら、今の生産効率がpだとして、1秒だけクッキーを焼いて食べてもその後の生産効率はpのままであり、食べる時間Aだけ無駄になるからである。

そうすると、最短でクッキーを食べ続ける行動をとっても生産効率は2倍ずつ上がっていく。なので、生産効率がNに到達するまでにクッキーを食べる回数はO(logN)回ということになり、これ以上の回数クッキーを食べることに意味はない。

クッキーを食べる回数をk回に決め打ちして、そのときにN枚作るための時間の最小値を求めていく。k回食べるときに、そのクッキーを焼く間隔をs_1 , s_2, ... ,s_{k+1}秒とすると、かかる時間はA * k + \sum_{i=1}^{k+1} s_iで、焼くことのできるクッキーの枚数はs_1 * s_2 * ... * s_{k+1}枚である。

焼けるクッキーの枚数についてなぜこうなるのかというと、順番にシミュレーションしていけばこうなることが分かる。はじめは1枚/s焼けて、s_1秒後に食べるのでs_1枚/s焼けるようになる。次はs_1枚/s焼けて、s_2秒後に食べるのでs_1 * s_2枚/s焼けるようになる。... これを繰り返していくと、最終的な状態は上で述べたようになっている。

クッキーの枚数s_1 * s_2 * ... * s_{k+1} \ge Nの条件を満たしつつ、かかる時間A * k + \sum_{i=1}^{k+1} s_iを最小化したい。初項は一定なので、実質 \sum_{i=1}^{k+1} s_iを最小化したい問題になった。

この時、積をできるだけ大きくするにはそれぞれの項の差が1以下になるようにするのが最適になっている。a \ge bとしたときに、Mab - M(a+1)(b-1) = M(a+1-b) \gt 0であるから、これが正しいことが分かった。

s_iの最大値を二分探索し、それをmとすると、(m-1) * (m-1) * ... * (m-1) * m  * m * ... * mのような形になるので、m-1の項の個数を全探索して、最小値を更新していけば良い。オーバーフローには気をつける(N以上になった瞬間に打ち切ってしまえば良い)。

実装(C++)

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define each(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); ++itr)
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

int main()
{
    ll n,a;
    cin >>n >>a;

    ll ans=n;

    for(int k=1; k<=40; ++k)
    {
        ll l=0, r=1000000;
        while(r-l>1)
        {
            ll m=(l+r)/2;

            bool ok=false;
            ll val=1;
            rep(i,k+1)
            {
                val*=m;
                if(val>=n)
                {
                    ok=true;
                    break;
                }
            }

            if(ok) r=m;
            else l=m;
        }

        rep(i,k+1)
        {
            ll val=1;

            bool ok=false;
            rep(j,i)
            {
                val*=r-1;
                if(val>=n)
                {
                    ok=true;
                    break;
                }
            }
            if(!ok)
            {
                rep(j,k+1-i)
                {
                    val*=r;
                    if(val>=n)
                    {
                        ok=true;
                        break;
                    }
                }
            }

            if(ok)
            {
                ll tmp = a*k + i*(r-1) + (k+1-i)*r;
                ans=min(ans,tmp);
            }
        }
    }

    cout << ans << endl;
    return 0;
}