くちもちとくらの重み付き最大マッチング実装日記 - priority queue 2の実装

http://kutimoti.hatenablog.com/entry/2018/11/14/194510

前回です

今回でわかったんですけど,日記なので前回の実装ミスったとか普通にあります

p.q.2 とは

p.q.2(priority queue 2)も重み付き最大マッチングで高速化に使われるものです.p.q.1を使って実装します.

p.q.2ではp.q.1とは違ってグループに要素が優先度付きで挿入されていきます.

それぞれのグループにはactivenonactiveの設定がされていて

p.q.2ではactiveなグループに含まれている要素の中で一番優先度の低いものを答えるクエリを捌きます

操作は以下です.

  1. [insert]要素iを優先度p_iでグループgに挿入する.
  2. [erase]要素iを削除する.
  3. [find]activeなグループに含まれている要素の中で優先度の一番小さい要素を見つける.
  4. [subtract_delta]activeなグループに挿入されているすべての要素の優先度をδだけ小さくする.
  5. [generate new group]新しいグループを作る
  6. [delete group]グループを削除する.
  7. [change active status]グループのactive状態を変更する.
  8. [split]要素iを含んでいるグループをi以下の要素のグループとiより大きいグループで分ける.

p.q.2 の実装方針

まず,グループを追加していくのは面倒臭いので,必要なグループの数をコンストラクタで指定させることにしました.

(5),(6)の削除の操作も処理が多くなるので,(5),(6)を併合してclear()としました.

  1. pqA[g]はグループgに属する要素を管理するp.q.1です.
  2. _isActive[g]はグループgがactiveかどうかをboolで持ったものです.
  3. pqBはそれぞれのactiveなグループからfindした値を入れたものです(つまりこれをfindすれば,p.q.2のfindが達成できます)
  4. delta_last[g]は操作1,2,7,8を行ったときのpqB::delta(論文中ではΔと表現されている)の値です.これを保持することで,操作4によるグループ間の値の変化をうまく処理することが出来ます.(処理しているのはeval_delta)
  5. group[i]は要素iが属しているグループです.
  6. erase_from_BでpqBからグループgの最小の優先度の要素を取り除き,insert_to_BでpqBにグループgの最小の優先度の要素を挿入します.これにより,active状態を変更したときの挙動や,グループにinsert,eraseしたときのpqBに対する処理が出来ます.

注意

この実装のために,前回のsplay_mqpq_1を変更しました.

splay_mq

mq_index()の返す型をKeyからPに変えました(p.q.2で最小の優先度を取得する必要があったため)

pq_1

それに伴ってfind()の返す型を::std::pair<priority_type,element_type>に変えました.

実装

pq_2

#include <vector>

class pq_2{
  using element_type = int;
  using priority_type = long long;
  using group_type = ::std::size_t;
  priority_type delta;

  const ::std::size_t N;
  const ::std::size_t G;

  ::std::vector<pq_1*> pqA;
  ::std::vector<bool> _isActive;
  ::std::vector<priority_type> delta_last;
  ::std::vector<group_type> group;
  pq_1 pqB;
  void eval_delta(group_type g){
    if(_isActive[g]){
      pqA[g]->delta = pqA[g]->delta + pqB.delta - delta_last[g];
    }
    delta_last[g] = pqB.delta;
  }
  void erase_from_B(group_type g){
    if(_isActive[g]){
      auto p = pqA[g]->find();
      if(p.second == -1) return;
      pqB.erase(p.second);
    }
  }
  void insert_to_B(group_type g){
    if(_isActive[g]){
      auto p = pqA[g]->find();
      if(p.second == -1) return;
      pqB.insert(p.second , p.first - pqA[g]->delta);
    }
  }
public:
  pq_2(::std::size_t ele_num , ::std::size_t group_num) : N(ele_num) , G(group_num){
    pqA.assign(G , nullptr);
    _isActive.assign(G , false);
    delta_last.assign(G , 0);
    group.assign(N , -1);
    for(int i = 0;i < G;i++){
      pqA[i] = new pq_1();
    }
  }
  ~pq_2(){
    for(auto & pqa : pqA){
      if(pqa) delete pqa , pqa = nullptr;
    }
  }
  // (1) insert an element i with priority pi to group g
  void insert(group_type g , element_type  i , priority_type pi){
    eval_delta(g);
    erase_from_B(g);
    pqA[g]->insert(i,pi);
    group[i] = g;
    insert_to_B(g);
  }
  // (2) delete an element i
  void erase(element_type i){
    pqA[group[i]]->erase(i);
    group_type g = group[i];
    erase_from_B(g);
    pqB.erase(i);
    group[i] = -1;
    insert_to_B(g);
  }
  // (3) find an active element with the minimal priority
  ::std::pair<priority_type,element_type> find(){
    return pqB.find();
  }
  // (4) decrease the priorityies of all the active elements by some real numbers delta
  void subtract_delta(priority_type d){
    pqB.subtract_delta(d);
  }
  // (5) and (6) delete a group g and generate a new empty group(nonactive)
  void clear(group_type g){
    erase_from_B(g);
    delete pqA[g] , pqA[g] = nullptr;
    pqA[g] = new pq_1();
    _isActive[g] = false;
    delta_last[g] = pqB.delta;
  }
  // (7.1) change the status of a group from nonactive to active
  void activate(group_type g){
    eval_delta(g);
    _isActive[g] = true;
    insert_to_B(g);
  }

  // (7.2) change the status of a group from active to nonactive
  void nonactivate(group_type g){
    eval_delta(g);
    erase_from_B(g);
    _isActive[g] = false;
  }

  // (8) split a group according to an element in it
  void split(element_type i , group_type save_to){
    group_type g = group[i];
    if(g == -1) return;
    eval_delta(g);
    auto p = pqA[g]->tree->split(i);
    delete pqA[g]->tree;
    pqA[g]->tree = p.first;
    pqA[save_to]->tree = p.second;
  }
};

次回

何もわかっていないので僕が論文を読んでからのお楽しみ

くちもちとくらの重み付き最大マッチング実装日記 - priority queue 1の実装

これからこの論文にもとづいて

重み付き最大マッチングをO(EVlogV)で実装するアルゴリズムを書いていきたいと思います.よろしくお願いします.

またkutimoti/MaximalWeightedMatching - GitHubソースコードは見られます.

言語はC++です.

p.q.1 とは

p.q.1(priority queue 1)はこの重み付き最大マッチングで高速化をするために必要なデータ構造です.またp.q.2でも使います.

操作は以下の4つです.

  1. [insert]要素iを優先度p_iで挿入する.
  2. [erase]要素を削除する.
  3. [mq_index]優先度の一番小さい要素を見つける.
  4. [subtract_delta]挿入されているすべての優先度をδだけ小さくする.

p.q.1 の実装方針

splay tree(平衡二分木) でAOJ RMQのように実装します.

insert , eraseの操作は平衡二分木なのでできます.

splay tree の各頂点に部分木の中で最小の優先度p_iと要素iを持つstd::pair<int,int> mq = {p_i , i}をもってやることで,findにも答えられます.

subtract_deltaは,Δ = Σδを持っておくことで,次からのinsertp_i + Δで挿入することで,挿入されている優先度を引き算せずに,優先度を並べることを達成できます.

実装

splay tree

splay tree の実装はSplay tree - Wikipediaが一番わかりやすいです.

#include <algorithm>
#include <set>

class splay_mq{
  using i64 = long long;
  using Key = i64;
  using T = i64;
  using P = ::std::pair<T,Key>;
  const long long INF = 1e18;
  const P PINF = {INF , -1};
  struct node{
    node* left;
    node* right;
    node* parent;
    Key key;
    T value;
    P mq;
    node(const Key& key , const T& val) : left(nullptr) , right(nullptr) , parent(nullptr) , key(key) , value(val) , mq({val , key}){}
    ~node(){
      if(left) delete left , left = nullptr;
      if(right) delete right , right = nullptr;
    }
  };
  node* root;
  T get_value(const node* n){
    if(!n) return INF;
    return n->value;
  }
  P get_mq(const node* n){
    if(!n) return PINF;
    return n->mq;
  }
  node* fix(node * n){
    if(!n) return nullptr;
    n->mq = ::std::min(::std::min(get_mq(n->left) , get_mq(n->right)) , {get_value(n) , n->key});
    return n;
  }
  void set_left(node * par , node * x){
    if(par)
      par->left = x;
    if(x)
      x->parent = par;
    fix(par);
  }
  void set_right(node * par , node * x){
    if(par)
      par->right = x;
    if(x)
      x->parent = par;
    fix(par);
  }
  void zig(node * x){
    node* p = x->parent;
    set_left(p,x->right);
    if(p->parent){
      if(p->parent->left == p)
        set_left(p->parent , x);
      else
        set_right(p->parent , x);
    }
    else{
      x->parent = nullptr;
    }
    set_right(x , p);
  }
  void zag(node * x){
    node* p = x->parent;
    set_right(p,x->left);
    if(p->parent){
      if(p->parent->left == p)
        set_left(p->parent , x);
      else
        set_right(p->parent , x);
    }
    else{
      x->parent = nullptr;
    }
    set_left(x , p);
  }
  node* splay(node* x){
    if(!x) return nullptr;
    while(x->parent){
      if(!x->parent->parent){
        if(x->parent->left == x){
          zig(x);
        }
        else{
          zag(x);
        }
      }
      else if(x->parent->parent->left == x->parent && x->parent->left == x){
        zig(x->parent);
        zig(x);
      }
      else if(x->parent->parent->left == x->parent && x->parent->right == x){
        zag(x->parent);
        zig(x);
      }
      else if(x->parent->parent->right == x->parent && x->parent->right == x){
        zag(x->parent);
        zag(x);
      }
      else{
        zig(x->parent);
        zag(x);
      }
    }
    return root = x;
  }
public:
  splay_mq() : root(nullptr){}
  splay_mq(node * root) : root(root){}
  ~splay_mq(){if(root) delete root , root = nullptr; }
  bool find(const Key& key){
    node * z = root;
    node* p = nullptr;
    while(z){
      p = z;
      if(z->key < key)
        z = z->right;
      else if(key < z->key)
        z = z->left;
      else{
        splay(z);
        return true;
      }
    }
    splay(p);
    return false;
  }
  void insert(Key key , T val){
    if(find(key))
      return;
    node* z = new node(key,val);
    if(!root)
      root = z;
    else if(root->key < key){
      set_right(z , root->right);
      set_right(root , z);
    }
    else{
      set_left(z , root->left);
      set_left(root , z);
    }
    splay(z);
  }
  bool erase(Key key){
    if(!find(key)){
      return false;
    }
    node * z = root;
    if(!z->left && !z->right)
      root = nullptr;
    else if(!z->left){
      root = z->right;
      root->parent = nullptr;
    }
    else if(!z->right){
      root = z->left;
      root->parent = nullptr;
    }
    else{
      node * lm = z->left;
      while(lm->right){
        lm = lm->right;
      }
      z->left->parent = nullptr;
      splay(lm);
      root = lm;
      set_right(root , z->right);
    }
    fix(root);
    return true;
  }
  Key mq_index(){
    return get_mq(root).second;
  }
  ::std::pair<splay_mq*,splay_mq*> split(Key key){
    if(!root)
      return {new splay_mq() , new splay_mq()};
    find(key);
    splay_mq* ngr ,* gr;
    if(root->key <= key){
      root->right->parent = nullptr;
      gr = new splay_mq(root->right);
      root->right = nullptr;
      fix(root);
      ngr = new splay_mq(root);
    }
    else{
      root->left->parent = nullptr;
      ngr = new splay_mq(root->left);
      root->left = nullptr;
      fix(root);
      gr = new splay_mq(root);
    }
    root = nullptr;
    return {ngr , gr};
  }
};

pq_1

class pq_1{
public:
  using element_type = long long;
  using priority_type = long long;
  priority_type delta;
  splay_mq *tree;
  pq_1() : delta(0) , tree(new splay_mq()){}
  pq_1(splay_mq * tr) : delta(0) , tree(tr){}
  ~pq_1(){if(tree) delete tree , tree = nullptr;}

  // (1) insert an element i with priority pi
  void insert(element_type i , priority_type pi){
    tree->insert(i , pi + delta);
  }
  // (2) delete an element i
  void erase(element_type i){
    tree->erase(i);
  }
  // (3) find an element with the minimal priority
  element_type find(){
    return tree->mq_index();
  }
  // (4) subtract from the priorities of all the current elements some real number delta
  void subtract_delta(priority_type d){
    delta += d;
  }
};

次回

p.q.2を実装していきます.

今日から始めるWingBIT - 非再帰SegmentTreeのようなもの

再帰SegmentTreeと一緒にはしたくない()

経緯

トイレでなんとなく思いついたデータ構造がうまくいったので好きになっているだけの話です(論文があったというのは内緒で)

BIT(BinaryIndexedTree)を知っていますか?

このような感じでそれぞれの要素を持ち,[0,k]の和を求めるデータ構造です.

f:id:Kutimoti:20181005223207p:plain

O(logN)区間和が取得できる上,SegmentTreeより定数倍が早い,実装が軽いことで有名ですね.

しかし,,,

BITは和以外では使いみちがあまりわからない

RMQ(Range Minimum Query)などは処理することができません...悲しいね

そこで†WingBIT†なるものを紹介します.

WingBIT

f:id:Kutimoti:20181005223224p:plain

このように今まで空いていたスペースにBITをもう一枚入れる感じのデータ構造です.(名前の由来はここから来ています)

いまから具体的にクエリをどのように捌くか説明します.

点更新 - update

index = 3の場所を更新する例です.

だんだん高さをあげていき,updateをします.

まず,3の真上にあるBlueとRedの場所をマークします.

f:id:Kutimoti:20181005223240p:plain Blue(3)は一番下のノードなので,値を更新して,次のBlueにマークを移動させます.

移動は+depthをするとできます.

f:id:Kutimoti:20181005223254p:plain

次にdepth=2をみるとRed(6)があります. これを下の2つのノードを使って更新します. そして次のRedにマークを移動させます.

f:id:Kutimoti:20181005223308p:plain

depth=4をみるとBlue(4)があります 同じように下の2つのノードを使って更新します. そして次のBlueにマークを移動させます.

f:id:Kutimoti:20181005223319p:plain

これを繰り返し,depthが一番上に行くまで繰り返すことで,更新が出来ます.

区間取得 - get_inter

一例を見てみましょう.

f:id:Kutimoti:20181005223333p:plain

f:id:Kutimoti:20181005223344p:plain

区間は必ず,左側がRed,右側がBlueでできていませんか?これを使って実装していきます.(RedとBlueでWingが出来ています.)

先と同じように下から処理していきます.

左側をRedにマーク,右側をBlueにマークします.

f:id:Kutimoti:20181005223359p:plain

depth=1を見ると,Blue(7)があります.これは区間の一部なので,利用していきます.

次の位置は-depthをすることで移動できます.

f:id:Kutimoti:20181005223404p:plain

depth=2を見ると,Red(6),Blue(6)があります.これも区間の一部なので,利用します.

f:id:Kutimoti:20181005223407p:plain

RedとBlueが交差したので終了です.

f:id:Kutimoti:20181005223410p:plain

実装例

Nim Langで書いています.

shl...左シフト
shr...右シフト
seq...C++のvectorと思っていただければ
import algorithm
type
    WingBIT[Monoid] = object
        rw : seq[Monoid]
        lw : seq[Monoid]
        ide : Monoid
        sz : int
        update_func : proc(node : Monoid , x : Monoid) : Monoid
        f : proc(x : Monoid , y : Monoid) : Monoid
proc newWingBIT*[Monoid](n : int , ide : Monoid,
    update_func : proc(node : Monoid , x : Monoid) : Monoid , 
    f : proc(x : Monoid , y : Monoid) : Monoid) : WingBIT[Monoid] =
    var bit : WingBIT[Monoid]
    bit.rw = @[]
    bit.lw = @[]
    bit.sz = 1
    while bit.sz < n: bit.sz = bit.sz * 2
    bit.rw.setLen(bit.sz + 1)
    bit.lw.setLen(bit.sz + 1)
    bit.rw.fill(ide)
    bit.lw.fill(ide)
    bit.ide = ide
    bit.update_func = update_func
    bit.f = f
    return bit

proc update*[Monoid](bit : var WingBIT[Monoid] , k : int , x : Monoid) =
    var depth = 1
    var right = k
    var left = bit.sz + 1 - k
    if (right and depth) > 0:
        bit.rw[right] = bit.update_func(bit.rw[right],x)
        right = right + depth
    if (left and depth) > 0:
        bit.lw[left] = bit.update_func(bit.lw[left],x)
        left = left + depth
    depth = 2
    while depth <= bit.sz:
        var dd = depth shr 1
        if (left and depth) > 0:
            bit.lw[left] = bit.f(bit.rw[bit.sz - left + dd],bit.lw[left - dd])
            left = left + depth
        if (right and depth) > 0:
            bit.rw[right] = bit.f(bit.rw[right - dd],bit.lw[bit.sz - right + dd])
            right = right + depth
        depth = depth shl 1

proc get_inter*[Monoid](bit : WingBIT[Monoid] , left : int , right : int) : Monoid =
    var al = bit.ide
    var ar = bit.ide
    var depth = 1
    var l = bit.sz + 1 - left
    var r = right
    while bit.sz + 1 - l <= r:
        if (l != bit.sz) and ((l and depth) > 0):
            al = bit.f(al,bit.lw[l])
            l -= depth
        if (r and depth) > 0:
            ar = bit.f(bit.rw[r],ar)
            r -= depth
        depth = depth shl 1
    return bit.f(ar,al)

# verify arc008 - d

import strutils
import sequtils

var temp = stdin.readLine.split.map(parseInt)

var N = temp[0]
var M = temp[1]

var p : array[101010,int64]
var a : array[101010,float64]
var b : array[101010,float64]

var se : seq[int64] = @[]

for i in 0..<M:
    var temp2 = stdin.readLine.split
    p[i] = temp2[0].parseInt
    a[i] = temp2[1].parseFloat
    b[i] = temp2[2].parseFloat
    se.add(p[i])
sort(se,system.cmp)

type TT = tuple[a : float64 , b : float64]
proc ffunc(x : TT,y : TT) : TT=
    return (y.a * x.a , y.a * x.b + y.b)

proc uupdate(x : TT,y : TT) : TT=
    return y

var ran = newWingBIT(M + 1000,(1.0,0.0),uupdate,ffunc)

var mr : float64 = 1.0
var ir : float64 = 1.0
for i in 0..<M:
    p[i] = lowerBound(se,p[i]) + 1
    ran.update((int)p[i],(a[i],b[i]))
    var tup : TT = ran.rw[ran.sz]
    var rrr : float64 = tup.a + tup.b
    mr = max(mr,rrr)
    ir = min(ir,rrr)
echo ir.formatFloat
echo mr.formatFloat

利点

再帰SegmentTreeより早い.(無駄が無い)

実装がそんなに大変なわけでもない

名前がかっこいい

SRM 735 MaxSquare 「よくある二分探索の変なことをやるやつ」

なんやこのテク...っていう気持ちになったので記事を書きます

http://community.topcoder.com/stat?c=problem_statement&pm=14929

問題の解説...https://www.topcoder.com/blog/single-round-match-735-editorials/

りんごさんの参考解説動画(ICPCの類似した問題)

https://youtu.be/agCN6bPxeE4?t=6052

https://youtu.be/agCN6bPxeE4?t=11979

問題の本質

数列Bが与えられる.
このときf(i,j) = (B[j] - B[i]) * (j - i + 1)の最大値を求めよ.

平面で考えてみる

横軸をindex,縦軸をBの値のと置くと、下のような図ができて

f:id:Kutimoti:20180831214244j:plain

この図でf(1,2)は、下の長方形の面積になります.'

f:id:Kutimoti:20180831214304j:plain

これの最大を求めたい...

最適な右上

ある頂点iを決めたとき、f(i,j)が最大になるjを「最適な右上」と呼ぶことにします.

左下になり得る頂点、右上になり得る頂点

こんな感じになりそう(実際なって証明ができるがなんとなくわかる)

f:id:Kutimoti:20180831214333j:plain

最適な右上が単調増加

もしこの図で

f:id:Kutimoti:20180831214358j:plain

頂点0の最適な右上が3

頂点1の最適な右上が2だとすると

f(0,2) < f(0,3) ==> ABCD < CDE ==> AB < E
f(1,3) < f(1,2) ==> DEFG < BDF ==> EG < B

足すと
ABEG < BE ==> AG < 0(は?)

これは矛盾

つまり

ある左下の頂点Lを決めたときの最適な右上をRとすると、Lより左側の頂点を左下としたときの最適な右上はRより左側、右側についても同様が成り立つ

f:id:Kutimoti:20180831214431j:plain

これを使うと,左下の頂点の区間[left,right),右上の頂点の区間を'[lo,hi)'とすると

(下のsolve関数を見たほうがいいかもしれない)

1.mid = (left + right) / 2番目の左下の頂点をLとする.

2.Lの最適な右上を調べ、indexをrとする.

3.答えを更新

4.[left,mid),[lo,r + 1)に分割して1.をする.

5.[mid + 1,right),[r,hi)に分割して1.をする.

これで答えがO(NlogN)で求まります(すごい)

Source

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(int (i) = (s);(i) <= (e);(i)++)
#define all(x) x.begin(),x.end()

class MaxSquare{
public:
  using P = pair<i64,i64>;
  vector<P> L,R;

  vector<P> normalize(vector<P> v){
    vector<P> ret;
    for(int i = 0;i < v.size();i++){
      while(ret.size() > 0 && ret[ret.size() - 1].second <= v[i].second){
        ret.pop_back();
      }
      ret.push_back(v[i]);
    }
    return ret;
  }

  vector<P> flip(vector<P> v){
    for(int i = 0;i < v.size();i++){
      v[i].second *= -1;
    }
    reverse(v.begin(),v.end());
    return v;
  }

  i64 solve(int left,int right,int lo,int hi){
    int mid = (left + right) / 2;
    int bestl = -1;
    i64 best = -1;
    for(int i = lo;i < hi;i++){
      i64 cur = (L[mid].first - R[i].first) * (L[mid].second - R[i].second);
      if(cur > best){
        best = cur;
        bestl = i;
      }
    }

    if(left < mid){
      best = max(best,solve(left,mid,lo,bestl + 1));
    }
    if(mid + 1 < right){
      best = max(best,solve(mid + 1,right,bestl,hi));
    }
    return best;
  }

  i64 getMaxSum(i64 n,i64 s,i64 q,i64 o,vector<i64> x,vector<i64> y){

    vector<i64> b(n);
    rep(i,0,n - 1){
      b[i] = (s / (1LL << 20)) % q + o;

      i64 s0 = (s * 621) % (1LL << 51);
      i64 s1 = (s * 825) % (1LL << 51);
      i64 s2 = (s * 494) % (1LL << 51);
      i64 s3 = (s *  23) % (1LL << 51);

      s = s3;
      s = (s * (1LL << 10) + s2) % (1LL << 51);
      s = (s * (1LL << 10) + s1) % (1LL << 51);
      s = (s * (1LL << 10) + s0 + 11) % (1LL << 51);
    }
    for(int i = 0;i < x.size();i++){
      b[x[i]] = y[i];
    }
    auto X = b;
    for(int i = 1;i < n;i++){
      b[i] += b[i - 1];
    }
    vector<P> B;
    B.push_back({-1,0});
    for(int i = 0;i < n;i++){
      B.push_back({i,b[i]});
    }
    L = flip(normalize(flip(B)));
    R = normalize(B);

    i64 ans = 2 * solve(0,L.size(),0,R.size());
    if(ans == 0){
      ans = 2 * X[0];
      for(int i = 0;i < n;i++){
        ans = max(ans , 2 * X[i]);
      }
    }
    return ans;
  }
};

AtCoder Regular Contest 074 E - RGB Sequence

https://beta.atcoder.jp/contests/arc074/tasks/arc074_c

解法

左から色を決めていくとする.

すでに塗った部分の色をすべて覚えておくことはできないので、情報量を落とさなければならない.

どうしよう?

ここで問題の性質である色の種類がxiであるについて考える.

区間に色を含んでいることはどうやってわかるだろう...?

これは,各色の一番右の場所を覚えておけば良い.

なので,i番目のマスを更新する際.i == rとなる与えられている区間[l,r]について,

遷移元の状態がx種類であるかどうかを見てやれば良い.

なので

  dp[i + 1][i + 1][g][b] += dp[i][r][g][b];
  //g,bについても同様に

である.

しかし、これでは間に合わない.

よく見ると,i = max({r,g,b})であるので,遷移を縮めることができる.

これでOK.

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(int (i) = (s);(i) <= (e);(i)++)

int N,M;
vector<int> L,R,X;
vector<int> ri[303];

i64 dp[303][303][303];
i64 MOD = 1e9 + 7;
int main(){
  cin >> N >> M;
  L.resize(M);
  R.resize(M);
  X.resize(M);
  rep(i,0,M - 1) cin >> L[i] >> R[i] >> X[i];
  rep(i,0,M - 1) ri[R[i]].push_back(i);
  dp[0][0][0] = 1;

  rep(r,0,N - 1){
    rep(g,0,N - 1){
      rep(b,0,N - 1){
        int next = max({r,g,b}) + 1;
        //red
        {
          bool ok = true;
          int MIN = min({g,b});
          int MAX = max({g,b});
          for(auto idx : ri[next]){
            int cnt = 1;
            if(MAX >= L[idx]) cnt++;
            if(MIN >= L[idx]) cnt++;
            ok = ok && cnt == X[idx];
          }
          if(ok){
            dp[next][g][b] += dp[r][g][b];
            dp[next][g][b] %= MOD;
          }
        }
        {
          bool ok = true;
          int MIN = min({r,b});
          int MAX = max({r,b});
          for(auto idx : ri[next]){
            int cnt = 1;
            if(MAX >= L[idx]) cnt++;
            if(MIN >= L[idx]) cnt++;
            ok = ok && cnt == X[idx];
          }
          if(ok){
            dp[r][next][b] += dp[r][g][b];
            dp[r][next][b] %= MOD;
          }
        }
        {
          bool ok = true;
          int MIN = min({g,r});
          int MAX = max({g,r});
          for(auto idx : ri[next]){
            int cnt = 1;
            if(MAX >= L[idx]) cnt++;
            if(MIN >= L[idx]) cnt++;
            ok = ok && cnt == X[idx];
          }
          if(ok){
            dp[r][g][next] += dp[r][g][b];
            dp[r][g][next] %= MOD;
          }
        }
      }
    }
  }

  i64 ans = 0;
  rep(i,0,N - 1){
    rep(j,0,N - 1){
      ans = (dp[N][i][j] + ans) % MOD;
      ans = (dp[i][N][j] + ans) % MOD;
      ans = (dp[i][j][N] + ans) % MOD;
    }
  }

  cout << ans << endl;
}

AtCoder Regular Contest 100 E - Or Plus Max

問題URL : https://beta.atcoder.jp/contests/arc100/tasks/arc100_c

問題

長さ2^Nの整数列A(0-indexed)がある.

1 <= K <= 2^N - 1を満たすすべての整数 K について, 次の値を求める.

0 <= i < j <= 2^N - 1,(i or j) <= K のとき, A_i + A_jの最大値を求めよ.

問題を見て思ったこと

(i or j) <= K から i < j <= K であることがわかる.

(i or j) <= Kなので,うまくbitDPをしそう...

考察

入力例3について考えてみます.

次のように,最大値を2番目まで持つことでうまくやることができます.

f:id:Kutimoti:20180701233634j:plain

i = 0 のときの MAX1st,MAX2nd を作る.

f:id:Kutimoti:20180701233704j:plain

ここで,最大値と同時にiの値も記録するのが大事です(後で説明します)

K = 1のときの値を求める

f:id:Kutimoti:20180701233735j:plain

K = 1のときは j = 0,1 が選べるので,MAX1st[i]の値は75に,MAX2nd[i]の値は26になります.

ここでMAX1st[i] + MAX2nd[i]を取ると,K = 1のときの答えが得られます.

K = 2のときの値を求める.

f:id:Kutimoti:20180701233809j:plain

K = 2のときは j = 0,2 が選べるので,MAX1st[i]の値は75に,MAX2nd[i]の値は45になります.

ここでMAX1st[i] + MAX2nd[i]を取ると,K = 1のときの答えが得られます.

K = 3のときの値を求める.

f:id:Kutimoti:20180701233924j:plain

ここが大事です.

K = 3のときは j = 0,1,2が選べますが,

j = 0 を見る必要はあるでしょうか?ないですね.

なぜなら,j = 0 の値は j = 1,2 ですでに評価されています(その結果,MAX1stが75になっている)

なので,見るべきは K = 3からビットを一つづつ消した,j = 1,2だけでOKです.

また,青色の矢印がバツされています.なぜでしょう?

これはi < jの条件です.重複を防いでいます.これのためにMAX2ndを用意しています.

これを繰り返すことで (i or j) == K のときの値を計算できます.

(i or j) <= K

(i or j) == K について考えていたので,後は今までの最大値を取るだけでOKです.

source code

resとans要らない(草)

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(int (i) = (s);(i) <= (e);(i)++)
#define all(x) x.begin(),x.end()

int N;
vector<i64> A;

vector<i64> ans;
vector<i64> MAX[2];
vector<i64> mi[2];
vector<i64> res;
int main(){
  cin >> N;
  A.resize(1 << N);
  ans.resize(1 << N,0);
  MAX[0].resize(1 << N,0);
  MAX[1].resize(1 << N,0);
  mi[0].resize(1 << N,0);
  mi[1].resize(1 << N,0);
  res.resize(1 << N,0);
  rep(i,0,(1 << N) - 1) cin >> A[i];
  MAX[0][0] = A[0];
  mi[0][0] = 0;
  ans[0] = A[0];
  res[0] = A[0];
  int cnt = 0;
  rep(s,1,(1 << N) - 1){
    vector<pair<i64,i64>> vec;
    vec.push_back({A[s],s});
    rep(i,0,N){
      if(s & (1 << i)){
        vec.push_back({MAX[0][s & ~(1 << i)],mi[0][s & ~(1 << i)]});
        vec.push_back({MAX[1][s & ~(1 << i)],mi[1][s & ~(1 << i)]});
      }
    }
    sort(all(vec));
    //重複削除
    vec.erase(unique(all(vec)),vec.end());
    //最大値なので
    reverse(all(vec));
    if(vec.size() >= 2){
      MAX[0][s] = vec[0].first;
      mi[0][s] = vec[0].second;
      MAX[1][s] = vec[1].first;
      mi[1][s] = vec[1].second;
      ans[s] = MAX[0][s] + MAX[1][s];
    }
    //これ多分要らない
    else{
      MAX[0][s] = vec[0].first;
      mi[0][s] = vec[0].second;
      ans[s] = MAX[0][s];
    }
    res[s] = max(ans[s],res[s - 1]);
    cout << res[s] << endl;
  }
}

うまく説明するのが難しいなぁ...

Codeforces 1000B Light It Up

http://codeforces.com/problemset/problem/1000/C

問題概要

時間 0,a1,a2,...,aN,M のタイミングで明かりがON,OFFを繰り返す.

ここで最大1個,タイミングを増やすことができる.

明かりがついている時間を最大化せよ.

解法

onの区間にXを挿入する

a_i-2 [off] a_i-1 [on] a_i [off] a_i+1

のような区間

a_i-2 [off] a_i-1 [on] X [off] a_i [on] a_i+1

とすることを考えると,このときの明かりがついている時間の合計は,Xを挿入する前の状態を考えて

(a[i]より左側のonの時間) - 1 + (a[i]より右側のoffの時間)

となる.なぜなら,なるべくonの区間を長くしようとするとX = (a_i) - 1となるからである

offの区間にXを挿入する

a_i-2 [on] a_i-1 [off] a_i [on] a_i+1

のような区間

a_i-2 [on] a_i-1 [off] X [on] a_i [off] a_i+1

とすることを考えると,このときの明かりがついている時間の合計は,Xを挿入する前の状態を考えて

(a[i - 1]より左側のonの時間) + (a[i - 1]より右側のoffの時間) - 1

となる.なぜなら,なるべくonの区間を長くしようとするとX = (a_i - 1) - 1となるからである

Xの挿入する区間がonかoffかはindexの偶奇によって決まるので,それを実装してやれば良い.

source code

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define rep(i,s,e) for(int (i) = (s);(i) <= (e);(i)++)

int N;
i64 M;
vector<i64> A;

int main(){
  cin >> N >> M;
  A.resize(N + 2);
  A[0] = 0;
  rep(i,1,N) cin >> A[i];
  A[N + 1] = M;

  vector<i64> on;
  vector<i64> off;
  on.push_back(0);
  off.push_back(0);
  rep(i,1,N + 1){
    if(i % 2 == 1){
      on.push_back(A[i] - A[i - 1] + on.back());
      off.push_back(off.back());
    }
    else{
      off.push_back(A[i] - A[i - 1] + off.back());
      on.push_back(on.back());
    }
  }

  i64 ans = on.back();

  rep(i,1,N){
    if(i % 2 == 1){
      ans = max(ans , on[i] - 1 + off.back() - off[i]);
    }
    else{
      ans = max(ans , off.back() - off[i - 1] - 1 + on[i]);
    }
  }
  cout << ans << endl;
}

こういう問題をきれいに早く実装したいなぁ...