くちもちとくらの重み付き最大マッチング実装日記 - priority queue 1の実装

これからこの論文にもとづいて

重み付き最大マッチングをO(EVlogV)で実装するアルゴリズムを書いていきたいと思います.よろしくお願いします.

またkutimoti/MaximalWeightedMatching - GitHubソースコードは見られます.

言語はC++です.

p.q.1 とは

p.q.1(priority queue 1)はこの重み付き最大マッチングで高速化をするために必要なデータ構造です.またp.q.2でも使います.

操作は以下の4つです.

  1. [insert]要素iを優先度p_iで挿入する.
  2. [erase]要素を削除する.
  3. [mq_index]優先度の一番小さい要素を見つける.
  4. [subtract_delta]挿入されているすべての優先度をδだけ小さくする.

p.q.1 の実装方針

splay tree(平衡二分木) でAOJ RMQのように実装します.

insert , eraseの操作は平衡二分木なのでできます.

splay tree の各頂点に部分木の中で最小の優先度p_iと要素iを持つstd::pair<int,int> mq = {p_i , i}をもってやることで,findにも答えられます.

subtract_deltaは,Δ = Σδを持っておくことで,次からのinsertp_i + Δで挿入することで,挿入されている優先度を引き算せずに,優先度を並べることを達成できます.

実装

splay tree

splay tree の実装はSplay tree - Wikipediaが一番わかりやすいです.

#include <algorithm>
#include <set>

class splay_mq{
  using i64 = long long;
  using Key = i64;
  using T = i64;
  using P = ::std::pair<T,Key>;
  const long long INF = 1e18;
  const P PINF = {INF , -1};
  struct node{
    node* left;
    node* right;
    node* parent;
    Key key;
    T value;
    P mq;
    node(const Key& key , const T& val) : left(nullptr) , right(nullptr) , parent(nullptr) , key(key) , value(val) , mq({val , key}){}
    ~node(){
      if(left) delete left , left = nullptr;
      if(right) delete right , right = nullptr;
    }
  };
  node* root;
  T get_value(const node* n){
    if(!n) return INF;
    return n->value;
  }
  P get_mq(const node* n){
    if(!n) return PINF;
    return n->mq;
  }
  node* fix(node * n){
    if(!n) return nullptr;
    n->mq = ::std::min(::std::min(get_mq(n->left) , get_mq(n->right)) , {get_value(n) , n->key});
    return n;
  }
  void set_left(node * par , node * x){
    if(par)
      par->left = x;
    if(x)
      x->parent = par;
    fix(par);
  }
  void set_right(node * par , node * x){
    if(par)
      par->right = x;
    if(x)
      x->parent = par;
    fix(par);
  }
  void zig(node * x){
    node* p = x->parent;
    set_left(p,x->right);
    if(p->parent){
      if(p->parent->left == p)
        set_left(p->parent , x);
      else
        set_right(p->parent , x);
    }
    else{
      x->parent = nullptr;
    }
    set_right(x , p);
  }
  void zag(node * x){
    node* p = x->parent;
    set_right(p,x->left);
    if(p->parent){
      if(p->parent->left == p)
        set_left(p->parent , x);
      else
        set_right(p->parent , x);
    }
    else{
      x->parent = nullptr;
    }
    set_left(x , p);
  }
  node* splay(node* x){
    if(!x) return nullptr;
    while(x->parent){
      if(!x->parent->parent){
        if(x->parent->left == x){
          zig(x);
        }
        else{
          zag(x);
        }
      }
      else if(x->parent->parent->left == x->parent && x->parent->left == x){
        zig(x->parent);
        zig(x);
      }
      else if(x->parent->parent->left == x->parent && x->parent->right == x){
        zag(x->parent);
        zig(x);
      }
      else if(x->parent->parent->right == x->parent && x->parent->right == x){
        zag(x->parent);
        zag(x);
      }
      else{
        zig(x->parent);
        zag(x);
      }
    }
    return root = x;
  }
public:
  splay_mq() : root(nullptr){}
  splay_mq(node * root) : root(root){}
  ~splay_mq(){if(root) delete root , root = nullptr; }
  bool find(const Key& key){
    node * z = root;
    node* p = nullptr;
    while(z){
      p = z;
      if(z->key < key)
        z = z->right;
      else if(key < z->key)
        z = z->left;
      else{
        splay(z);
        return true;
      }
    }
    splay(p);
    return false;
  }
  void insert(Key key , T val){
    if(find(key))
      return;
    node* z = new node(key,val);
    if(!root)
      root = z;
    else if(root->key < key){
      set_right(z , root->right);
      set_right(root , z);
    }
    else{
      set_left(z , root->left);
      set_left(root , z);
    }
    splay(z);
  }
  bool erase(Key key){
    if(!find(key)){
      return false;
    }
    node * z = root;
    if(!z->left && !z->right)
      root = nullptr;
    else if(!z->left){
      root = z->right;
      root->parent = nullptr;
    }
    else if(!z->right){
      root = z->left;
      root->parent = nullptr;
    }
    else{
      node * lm = z->left;
      while(lm->right){
        lm = lm->right;
      }
      z->left->parent = nullptr;
      splay(lm);
      root = lm;
      set_right(root , z->right);
    }
    fix(root);
    return true;
  }
  Key mq_index(){
    return get_mq(root).second;
  }
  ::std::pair<splay_mq*,splay_mq*> split(Key key){
    if(!root)
      return {new splay_mq() , new splay_mq()};
    find(key);
    splay_mq* ngr ,* gr;
    if(root->key <= key){
      root->right->parent = nullptr;
      gr = new splay_mq(root->right);
      root->right = nullptr;
      fix(root);
      ngr = new splay_mq(root);
    }
    else{
      root->left->parent = nullptr;
      ngr = new splay_mq(root->left);
      root->left = nullptr;
      fix(root);
      gr = new splay_mq(root);
    }
    root = nullptr;
    return {ngr , gr};
  }
};

pq_1

class pq_1{
public:
  using element_type = long long;
  using priority_type = long long;
  priority_type delta;
  splay_mq *tree;
  pq_1() : delta(0) , tree(new splay_mq()){}
  pq_1(splay_mq * tr) : delta(0) , tree(tr){}
  ~pq_1(){if(tree) delete tree , tree = nullptr;}

  // (1) insert an element i with priority pi
  void insert(element_type i , priority_type pi){
    tree->insert(i , pi + delta);
  }
  // (2) delete an element i
  void erase(element_type i){
    tree->erase(i);
  }
  // (3) find an element with the minimal priority
  element_type find(){
    return tree->mq_index();
  }
  // (4) subtract from the priorities of all the current elements some real number delta
  void subtract_delta(priority_type d){
    delta += d;
  }
};

次回

p.q.2を実装していきます.