CF 599D - Spongebob and Squares
- 問題
http://codeforces.com/contest/599/problem/D
今回はA,B,Cで3完できたと思ったらBで若干考え違いして落としてしまった。Dは考え方は合ってたけど最後の詰めが甘かった。つらい。こういう考えついた問題をACに持っていけるようにしたい。。。
- 概要
自然数xが与えられる。(1<=x<=10^18)
n*mの格子状の四角形のテーブルを考えた時、そのテーブル内にある正方形の個数がxに一致する(n,m)の組の個数とそれぞれのn,mを答えよ。
例えば、3*5のテーブルなら、1辺の長さが1の正方形は15個、2の正方形は8個、3の正方形は5個なので、15+8+3=26という具合。図も丁寧についてるし分かる。
- アイデア
まず、nとmについては対称なので、n<=mの場合の考察で十分。
次にn*mのテーブルに対して、正方形が現れる個数は何個になるのかを考えた→S(n,m)と表す。
すると、1辺の長さとしてありうるのは1~nのn種類。(∵n<=m)
1辺の長さが:
1の正方形はn*m個、
2の正方形は上から1,2行目,2,3行目,...,(n-1),n行目の2行を使い、列に対しても同じ考察をすれば、
(n-1)*(m-1)個、
...
と考察していくことにより、1辺の長さがkである正方形の個数a_kは
a_k = (n-(k-1)) * (m-(k-1))
という結論になる。そしてS(n,m)はa_kのk=[1:n]の和であるから、式変形して
a_k = k^2 - (m+n+2)*k + (n+1)(m+1)
で、これに高校数学で習った公式を適用して
S(n,m) = n(n+1)(2n+1)/6 - (m+n+2)*n(n+1)/2 + (n+1)(m+1)*n
である。
そして、x=S(n,m)となるn,mを求めればよい、ということになる。方程式を立てて
x = (S(n,m)=) n(n+1)(2n+1)/6 - (m+n+2)*n(n+1)/2 + (n+1)(m+1)*n
6倍して
6x = n(n+1)(2n+1) - 3*(m+n+2)*n(n+1) + 6*(n+1)(m+1)*n
右辺を整理すると
6x = m*3n(n+1)-n^3+n ...(# : この式についてあとで言及する)
ここまで来ると明らかだが、これはnを固定すればmの1次方程式に帰着できる。mについて解くと
m = (6x+n^3-n) / 3n(n+1)
となる。mが整数値ならば、このnとmは答えの1つであるということが分かる。
つまり、nを固定した時、それに対応するmは1つまたは無いということがO(1)で計算できる。
あとはnを1から順に走査するだけ...なのだが、ここで自分は最後の壁に当たってしまった。
自分はコンテスト終了5分前ぐらいにちょうど上の考えにたどり着き、nを1から走らせて
「nとmがひっくり返っている解を見つけたところでループ終了」
というように書いていた。そしてTLEした。
http://codeforces.com/contest/599/submission/14390191
どの入力でTLEしたかというと、
200000800200
とある。これの解は、
4
1 200000800200
4 20000080021
20000080021 4
200000800200 1
↑こうなっている。
そして、自分の提出したコードは、この3番目の組、「20000080021 4」が見つかるまでループが回り続けることになる。これは当然TLEする。
当然、もっと早くにループを打ち切ることが可能である。ここで、上の式(#)について、
6x = m*3n(n+1)-n^3+n
これがmについての方程式だが、いまmには制約があり、n<=mである。この不等式を上の式に適用して
6x = m*3n(n+1)-n^3+n >= n*3n(n+1)-n^3+n = 2n^3 *+ 3n^2 + n
であり、つまり右辺は2n^3 *+ 3n^2 + n未満の値を取り得ない。この2次式はnについて単調増加なのは明らかなので、2n^3 *+ 3n^2 + nが6xより大きくなった時点でそれ以上のnにn<=mとなるような(n,m)の組は答えたり得ないということになる。
x<=10^18なので、nとしてはだいたい10^6より大きくなることはありえない。これは計算時間として十分間に合う。
最後に、得られた答えを折り返して終わり。n=mの組を折り返して2つ同じものを出さないようにだけ気をつける。
- 実装(C++)
#include <iostream> #include <string> #include <vector> #include <queue> #include <stack> #include <map> #include <algorithm> #include <set> #include <sstream> #include <utility> #include <cstdio> #include <cstdlib> #include <cstring> #include <cmath> #include <cctype> #include <climits> using namespace std; typedef long long ll; #define foreach(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); itr++) ll calc_m(ll n, ll x){ ll ret=-1; if( (6*x-n+n*n*n) % (3*n*(n+1)) == 0){ ret = (6*x-n+n*n*n) / (3*n*(n+1)); } //cout << "ret =" << ret <<endl; return ret; } int main(){ ll x; cin >> x; ll n=1; vector< pair<ll,ll> > ans; while(6*x>=2*n*n*n+3*n*n+n){ ll m=calc_m(n,x); if(m!=-1) ans.push_back(make_pair(n,m)); n++; } //折り返し ll lim=ans.size(); if(ans[ans.size()-1].first == ans[ans.size()-1].second) lim--; for(ll i=lim-1; i>=0; --i){ ans.push_back(make_pair(ans[i].second,ans[i].first)); } //output answer cout << ans.size() << endl; for(ll i=0; i<ans.size(); ++i){ cout << ans[i].first << " " << ans[i].second << endl; } }