AtCoder Regular Contest 100 E - Or Plus Max

問題URL : https://beta.atcoder.jp/contests/arc100/tasks/arc100_c

問題

長さ2^Nの整数列A(0-indexed)がある.

1 <= K <= 2^N - 1を満たすすべての整数 K について, 次の値を求める.

0 <= i < j <= 2^N - 1,(i or j) <= K のとき, A_i + A_jの最大値を求めよ.

問題を見て思ったこと

(i or j) <= K から i < j <= K であることがわかる.

(i or j) <= Kなので,うまくbitDPをしそう...

考察

入力例3について考えてみます.

次のように,最大値を2番目まで持つことでうまくやることができます.

f:id:Kutimoti:20180701233634j:plain

i = 0 のときの MAX1st,MAX2nd を作る.

f:id:Kutimoti:20180701233704j:plain

ここで,最大値と同時にiの値も記録するのが大事です(後で説明します)

K = 1のときの値を求める

f:id:Kutimoti:20180701233735j:plain

K = 1のときは j = 0,1 が選べるので,MAX1st[i]の値は75に,MAX2nd[i]の値は26になります.

ここでMAX1st[i] + MAX2nd[i]を取ると,K = 1のときの答えが得られます.

K = 2のときの値を求める.

f:id:Kutimoti:20180701233809j:plain

K = 2のときは j = 0,2 が選べるので,MAX1st[i]の値は75に,MAX2nd[i]の値は45になります.

ここでMAX1st[i] + MAX2nd[i]を取ると,K = 1のときの答えが得られます.

K = 3のときの値を求める.

f:id:Kutimoti:20180701233924j:plain

ここが大事です.

K = 3のときは j = 0,1,2が選べますが,

j = 0 を見る必要はあるでしょうか?ないですね.

なぜなら,j = 0 の値は j = 1,2 ですでに評価されています(その結果,MAX1stが75になっている)

なので,見るべきは K = 3からビットを一つづつ消した,j = 1,2だけでOKです.

また,青色の矢印がバツされています.なぜでしょう?

これはi < jの条件です.重複を防いでいます.これのためにMAX2ndを用意しています.

これを繰り返すことで (i or j) == K のときの値を計算できます.

(i or j) <= K

(i or j) == K について考えていたので,後は今までの最大値を取るだけでOKです.

source code

resとans要らない(草)

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(int (i) = (s);(i) <= (e);(i)++)
#define all(x) x.begin(),x.end()

int N;
vector<i64> A;

vector<i64> ans;
vector<i64> MAX[2];
vector<i64> mi[2];
vector<i64> res;
int main(){
  cin >> N;
  A.resize(1 << N);
  ans.resize(1 << N,0);
  MAX[0].resize(1 << N,0);
  MAX[1].resize(1 << N,0);
  mi[0].resize(1 << N,0);
  mi[1].resize(1 << N,0);
  res.resize(1 << N,0);
  rep(i,0,(1 << N) - 1) cin >> A[i];
  MAX[0][0] = A[0];
  mi[0][0] = 0;
  ans[0] = A[0];
  res[0] = A[0];
  int cnt = 0;
  rep(s,1,(1 << N) - 1){
    vector<pair<i64,i64>> vec;
    vec.push_back({A[s],s});
    rep(i,0,N){
      if(s & (1 << i)){
        vec.push_back({MAX[0][s & ~(1 << i)],mi[0][s & ~(1 << i)]});
        vec.push_back({MAX[1][s & ~(1 << i)],mi[1][s & ~(1 << i)]});
      }
    }
    sort(all(vec));
    //重複削除
    vec.erase(unique(all(vec)),vec.end());
    //最大値なので
    reverse(all(vec));
    if(vec.size() >= 2){
      MAX[0][s] = vec[0].first;
      mi[0][s] = vec[0].second;
      MAX[1][s] = vec[1].first;
      mi[1][s] = vec[1].second;
      ans[s] = MAX[0][s] + MAX[1][s];
    }
    //これ多分要らない
    else{
      MAX[0][s] = vec[0].first;
      mi[0][s] = vec[0].second;
      ans[s] = MAX[0][s];
    }
    res[s] = max(ans[s],res[s - 1]);
    cout << res[s] << endl;
  }
}

うまく説明するのが難しいなぁ...