裏紙

ほぼ競プロ、たまに日記

CS Academy #41 - Candles

問題

CS Academy

問題概要

n本のろうそくがある。i番目のろうそくの長さはh_iである。一晩ろうそくを使用すると長さは1だけ減る。

これから、m回の夜を過ごそうとしている。i日目の夜には、c_i本のろうそくを灯そうと考えている。最大で過ごせる夜の回数を求めよ。

  •  1 \le n,m \le 10^5
  •  1 \le h_i \le 10^5
  •  1 \le c_i \le 10^5

イデア

できるだけ多くのろうそくを残し続けた方がいいので、その日ごとに長い方からc_i本のろうそくを貪欲に選んで使用するのが最適であるのは分かる。 ただ、それを愚直に処理するのは無理なので、効率よく処理することを考える。

ろうそくを、長さの降順でソートしておけば、使用するろうそくの番号は区間になる。その使用した状況をBITで管理する。その時、1番目からc_i番目を使用することになるのだが、ろうそくを使用後にもソートされた状態を保ちたいので、c_i番目のろうそくの長さに注目し、その長さをxとすると、x+1以上の長さの位置までは素直に先頭から減少させ、xと同じ長さの区間だけはソートされた状態を保つためにできるだけ後ろのものを減少させる。その区間の位置の左端と右端は二分探索によって探せばよい。

実装(C++)

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i,n) for(int (i)=0;(i)<(int)(n);++(i))
#define all(x) (x).begin(),(x).end()
#define pb push_back
#define fi first
#define se second

struct BIT{
    // [1,n]
    int n; vector<ll> bit;
    // 初期化
    BIT(int _n){
        n = _n;
        bit = vector<ll>(n+1,0);
    }
    // sum of [1,i]
    ll sum(int i){
        ll s = 0;
        while(i>0){
            s += bit[i];
            i -= i & -i;
        }
        return s;
    }
    // add x in i-th element
    void add(int i, ll x){
        while(i<=n){
            bit[i] += x;
            i += i & -i;
        }
    }
};

int solve()
{
    int n,m;
    scanf(" %d %d", &n, &m);
    vector<int> h(n+1),c(m);
    rep(i,n) scanf(" %d", &h[i+1]);
    rep(i,m) scanf(" %d", &c[i]);

    h[0]=19191919;
    sort(all(h),greater<int>());

    BIT bit(n+1);
    rep(i,m)
    {
        if(n<c[i]) return i;

        int x = h[c[i]]+bit.sum(c[i]);
        if(x<=0) return i;

        // x+1 と x の境界を見つける(lのindexがx+1以上の値の右端)
        int l=0,r=c[i];
        while(r-l>1)
        {
            int mid = (l+r)/2;
            int tx = h[mid]+bit.sum(mid);
            if(tx>x) l=mid;
            else r=mid;
        }
        bit.add(1,-1);
        bit.add(l+1,1);

        // x と x-1 の境界を見つける(lのindexがxの値の右端)
        int same = c[i]-l;
        l=0,r=n+1;
        while(r-l>1)
        {
            int mid = (l+r)/2;
            int tx = h[mid]+bit.sum(mid);
            if(tx>=x) l=mid;
            else r=mid;
        }
        bit.add(l-same+1,-1);
        bit.add(l+1,1);
    }
    return m;
}

int main()
{
    printf("%d\n", solve());
    return 0;
}