読者です 読者をやめる 読者になる 読者になる

裏紙

ほぼ競プロ、たまに日記

CF 599D - Spongebob and Squares

  • 問題

http://codeforces.com/contest/599/problem/D

今回はA,B,Cで3完できたと思ったらBで若干考え違いして落としてしまった。Dは考え方は合ってたけど最後の詰めが甘かった。つらい。こういう考えついた問題をACに持っていけるようにしたい。。。

  • 概要

自然数xが与えられる。(1<=x<=10^18)
n*mの格子状の四角形のテーブルを考えた時、そのテーブル内にある正方形の個数がxに一致する(n,m)の組の個数とそれぞれのn,mを答えよ。

例えば、3*5のテーブルなら、1辺の長さが1の正方形は15個、2の正方形は8個、3の正方形は5個なので、15+8+3=26という具合。図も丁寧についてるし分かる。

まず、nとmについては対称なので、n<=mの場合の考察で十分。
次にn*mのテーブルに対して、正方形が現れる個数は何個になるのかを考えた→S(n,m)と表す。
すると、1辺の長さとしてありうるのは1~nのn種類。(∵n<=m)

1辺の長さが:
1の正方形はn*m個、
2の正方形は上から1,2行目,2,3行目,...,(n-1),n行目の2行を使い、列に対しても同じ考察をすれば、
(n-1)*(m-1)個、
...

と考察していくことにより、1辺の長さがkである正方形の個数a_kは
a_k = (n-(k-1)) * (m-(k-1))
という結論になる。そしてS(n,m)はa_kのk=[1:n]の和であるから、式変形して
a_k = k^2 - (m+n+2)*k + (n+1)(m+1)
で、これに高校数学で習った公式を適用して
S(n,m) = n(n+1)(2n+1)/6 - (m+n+2)*n(n+1)/2 + (n+1)(m+1)*n
である。

そして、x=S(n,m)となるn,mを求めればよい、ということになる。方程式を立てて
x = (S(n,m)=) n(n+1)(2n+1)/6 - (m+n+2)*n(n+1)/2 + (n+1)(m+1)*n
6倍して
6x = n(n+1)(2n+1) - 3*(m+n+2)*n(n+1) + 6*(n+1)(m+1)*n
右辺を整理すると
6x = m*3n(n+1)-n^3+n ...(# : この式についてあとで言及する)

ここまで来ると明らかだが、これはnを固定すればmの1次方程式に帰着できる。mについて解くと
m = (6x+n^3-n) / 3n(n+1)
となる。mが整数値ならば、このnとmは答えの1つであるということが分かる。

つまり、nを固定した時、それに対応するmは1つまたは無いということがO(1)で計算できる。

あとはnを1から順に走査するだけ...なのだが、ここで自分は最後の壁に当たってしまった。
自分はコンテスト終了5分前ぐらいにちょうど上の考えにたどり着き、nを1から走らせて
「nとmがひっくり返っている解を見つけたところでループ終了」
というように書いていた。そしてTLEした。
http://codeforces.com/contest/599/submission/14390191

どの入力でTLEしたかというと、
200000800200
とある。これの解は、

4
1 200000800200
4 20000080021
20000080021 4
200000800200 1
↑こうなっている。

そして、自分の提出したコードは、この3番目の組、「20000080021 4」が見つかるまでループが回り続けることになる。これは当然TLEする。

当然、もっと早くにループを打ち切ることが可能である。ここで、上の式(#)について、
6x = m*3n(n+1)-n^3+n
これがmについての方程式だが、いまmには制約があり、n<=mである。この不等式を上の式に適用して
6x = m*3n(n+1)-n^3+n >= n*3n(n+1)-n^3+n = 2n^3 *+ 3n^2 + n

であり、つまり右辺は2n^3 *+ 3n^2 + n未満の値を取り得ない。この2次式はnについて単調増加なのは明らかなので、2n^3 *+ 3n^2 + nが6xより大きくなった時点でそれ以上のnにn<=mとなるような(n,m)の組は答えたり得ないということになる。
x<=10^18なので、nとしてはだいたい10^6より大きくなることはありえない。これは計算時間として十分間に合う。

最後に、得られた答えを折り返して終わり。n=mの組を折り返して2つ同じものを出さないようにだけ気をつける。

#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <algorithm>
#include <set>
#include <sstream>
#include <utility>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <climits>
using namespace std;

typedef long long ll;
#define foreach(itr,c) for(__typeof(c.begin()) itr=c.begin(); itr!=c.end(); itr++)

ll calc_m(ll n, ll x){
  ll ret=-1;
  if( (6*x-n+n*n*n) % (3*n*(n+1)) == 0){
    ret = (6*x-n+n*n*n) / (3*n*(n+1));
  }
  //cout << "ret =" << ret <<endl;
  return ret;
}

int main(){
  ll x;
  cin >> x;

  ll n=1;
  vector< pair<ll,ll> > ans;

  while(6*x>=2*n*n*n+3*n*n+n){
    ll m=calc_m(n,x);
    if(m!=-1) ans.push_back(make_pair(n,m));
    n++;
  }

  //折り返し
  ll lim=ans.size();
  if(ans[ans.size()-1].first == ans[ans.size()-1].second) lim--;
  for(ll i=lim-1; i>=0; --i){
    ans.push_back(make_pair(ans[i].second,ans[i].first));
  }

  //output answer
  cout << ans.size() << endl;
  for(ll i=0; i<ans.size(); ++i){
    cout << ans[i].first << " " << ans[i].second << endl;
  }

}