裏紙

ほぼ競プロ、たまに日記

CF 811 E - Vladik and Entertaining Flags

問題

Problem - E - Codeforces

問題概要

n \times mのグリッドgがあり、各セルに数字がかかれている。グリッドの美しさを、連結成分の個数と定義する(同じ数字かつ上下左右のいずれかで触れていれば連結とする)。

  • クエリ(l,r): グリッドの左端をl、右端をrで切った時に残る部分の美しさを答えよ。

というクエリがq個与えられるので、それぞれに足して答えよ。

  •  1 \le n \le 10
  •  1 \le m \le 10^5
  •  1 \le g_{i,j} \le 10^6
  •  1 \le q \le 10^5
  •  1 \le l \le r \le m

イデア

SegnentTreeのような考え方で、各segmentに対して次の3つの情報を持たせる:

  • 区間の左端lにおけるグリッドの状態(サイズnの配列)
  • 区間の右端rにおけるグリッドの状態(サイズnの配列)
  • segment内の区間に存在する連結成分の個数

この情報を使うと、隣合うsegment同士のmergeは次のようにできる:

いま、segment A( [l_A , r_A ] , 連結成分の個数CC_A)とsegment B( [l_B , r_B ] , 連結成分の個数CC_B)をmergeするという状況を考える(隣り合うことを仮定するので、 r_A + 1 = l_Bとする)。merge後のsegmentをsegment Cとすると、範囲は当然 [l_A , r_B ] になる。

連結成分の個数についてはCC_A + CC_Bからスタートして、初めは各セルに固有のIDをふっておき、一段ずつ見ていきながらUnionFindっぽくIDを同じものにしていくというのを各段にやると実現できる。別に律儀にUFをする必要はなくて、nは非常に小さいので各段に対してO(n)かけて全体でO(n^2)かかっても間に合う。

以上から、1つのクエリに対してO(n^2 log m)で処理ができることが分かったので、これは十分高速である。

実装(C++)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

const int N = 10, M = 100000;

struct node{
    int n;
    int l[N],r[N];
    int cc = 0;

    node(){}
    node(int _n){
        n = _n;
    }

    // unite ID x & ID y
    void unite(int x, int y){
        --cc;
        rep(i,n){
            if(l[i]==x) l[i] = y;
            if(r[i]==x) r[i] = y;
        }

    }
};


int n,m;
int grid[N][M];

node dat[3*M];

// b.lとposの列が一致
node merge(node a, node b, int pos){
    if(a.cc == 0) return b;
    if(b.cc == 0) return a;

    node ret(n);
    rep(i,n){
        ret.l[i] = a.l[i];
        ret.r[i] = b.r[i];
    }
    ret.cc = a.cc + b.cc;

    rep(i,n){
        if(grid[i][pos-1] == grid[i][pos]){
            int ID_l = a.r[i], ID_r = b.l[i];
            if(ID_l != ID_r){
                ret.unite(ID_l, ID_r);
                a.unite(ID_l, ID_r);
            }
        }
    }

    return ret;
}

void build(int k, int l, int r){
    if(l+1==r){
        dat[k] = node(n);
        rep(i,n) dat[k].l[i] = dat[k].r[i] = l*n+i;
        dat[k].cc = n;

        for(int i=1; i<n; ++i){
            if(grid[i][l] == grid[i-1][l]) dat[k].unite(dat[k].l[i],dat[k].l[i-1]);
        }
        return;
    }

    build(2*k+1, l, (l+r)/2);
    build(2*k+2, (l+r)/2, r);
    dat[k] = merge(dat[2*k+1], dat[2*k+2], (l+r)/2);
}

// [a,b)
node query(int a, int b, int k, int l, int r){
    if(r<=a || b<=l) return node();
    if(a<=l && r<=b) return dat[k];

    node vl = query(a,b,2*k+1,l,(l+r)/2);
    node vr = query(a,b,2*k+2,(l+r)/2,r);
    return merge(vl, vr, (l+r)/2);
}

int main(){
    int q;
    scanf(" %d %d %d", &n, &m, &q);
    rep(i,n)rep(j,m) scanf(" %d", &grid[i][j]);

    build(0,0,m);

    while(q--){
        int ql,qr;
        scanf(" %d %d", &ql, &qr);
        --ql;
        node res = query(ql,qr,0,0,m);
        printf("%d\n", res.cc);
    }
    return 0;
}