tkmst201's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub tkmst201/Library

:heavy_check_mark: Test/AVL_Tree.2.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/associative_array"

#include "DataStructure/AVL_Tree.hpp"

#include <cstdio>
#include <cassert>
#include <algorithm>
#include <set>
#include <numeric>

int main() {
	int Q;
	scanf("%d", &Q);
	
	using ll = long long;
	using pll = std::pair<ll, ll>;
	
	AVL_Tree<pll> avl;
	
	while (Q--) {
		int q;
		ll k, v;
		scanf("%d %lld", &q, &k);
		auto it = avl.lower_bound({k, 0});
		
		if (q == 0) {
			scanf("%lld", &v);
			if (it != avl.end() && it->val.first == k) avl.erase(it);
			avl.insert({k, v});
		}
		else {
			printf("%lld\n", it == avl.end() || it->val.first != k ? 0 : it->val.second);
		}
	}
	
	// [begin, next] test

	auto elem = avl.enumerate();
	{
		std::vector<pll> elem2;
		for (auto it = avl.begin(); it != avl.end(); it = avl.next(it)) elem2.emplace_back(it->val);
		assert(elem == elem2);
	}
	// [end, prev] test

	{
		std::vector<pll> elem2;
		for (auto it = avl.prev(avl.end()); it != avl.end(); it = avl.prev(it)) elem2.emplace_back(it->val);
		std::reverse(begin(elem2), end(elem2));
		assert(elem == elem2);
	}
	
	// [*_bound, count*, k_th_*] test

	for (int i = 0; i < elem.size(); ++i) {
		const auto & e = elem[i];
		auto itelemup = std::upper_bound(begin(elem), end(elem), e);
		auto itelemlo = std::lower_bound(begin(elem), end(elem), e);
		assert(avl.lower_bound(e) != avl.end());
		assert(avl.lower_bound(e)->val == e);
		assert(itelemup == elem.end() ? avl.upper_bound(e) == avl.end() : avl.upper_bound(e)->val == *itelemup);
		
		const auto lt = itelemlo - elem.begin(), eq = itelemup - itelemlo, gt = elem.end() - itelemup;
		assert(avl.count_less_than(e) == lt);
		assert(avl.count_less_equal(e) == lt + eq);
		assert(avl.count_greater_than(e) == gt);
		assert(avl.count_greater_equal(e) == gt + eq);
		assert(avl.count(e) == eq);
		assert(avl.k_th_smallest(i + 1)->val == e);
		assert(avl.k_th_largest(elem.size() - i)->val == e);
	}
	
	// [erase, size] test

	std::set<pll> ss;
	for (const auto & e : elem) ss.insert(e);
	std::vector<int> ord(elem.size());
	std::iota(begin(ord), end(ord), 0);
	std::sort(begin(ord), end(ord), [&](int i, int j) {
		if (elem[i].second == elem[j].second) return elem[i].first > elem[j].first;
		return elem[i].second < elem[j].second;
	});
	assert(ss.size() == avl.size());
	for (int i : ord) {
		auto it = avl.erase(elem[i]);
		auto its = ss.find(elem[i]); ++its;
		assert(its == ss.end() ? it == avl.end() : it->val == *its);
		ss.erase(--its);
		assert(ss.size() == avl.size());
	}
}
#line 1 "Test/AVL_Tree.2.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/associative_array"

#line 1 "DataStructure/AVL_Tree.hpp"



#include <algorithm>
#include <cstdint>
#include <vector>
#include <utility>
#include <stack>

/**
 * @brief https://tkmst201.github.io/Library/DataStructure/AVL_Tree.hpp
 */
template<typename T>
struct AVL_Tree {
	using size_type = std::size_t;
	using value_type = T;
	using const_reference = const value_type &;
	
private:
	using uint32 = std::uint32_t;
	using int8 = std::int8_t;
	
public:
	struct Node;
	using node_ptr = Node *;
	using const_ptr = const Node * const;
	struct Node {
		value_type val;
		node_ptr par, child[2] {nullptr, nullptr};
		bool isr;
		int8 height[2] {};
		uint32 size[2] {};
		Node(const_reference val, node_ptr par, bool isr) : val(val), par(par), isr(isr) {}
	};
	
private:
	size_type n = 0;
	node_ptr root = nullptr;
	node_ptr e_ptr[2] {nullptr, nullptr};
	
public:
	AVL_Tree() = default;
	
	AVL_Tree(const AVL_Tree & rhs) {
		*this = rhs;
	}
	
	AVL_Tree(AVL_Tree && rhs) {
		*this = std::forward<AVL_Tree>(rhs);
	}
	
	~AVL_Tree() {
		clear();
	}
	
	AVL_Tree & operator =(const AVL_Tree & rhs) {
		if (this != &rhs) {
			clear();
			auto dfs = [](auto self, const_ptr q, node_ptr r) -> node_ptr {
				if (!q) return nullptr;
				node_ptr res = new Node(q->val, r, q->isr);
				for (int i = 0; i < 2; ++i) {
					res->height[i] = q->height[i];
					res->size[i] = q->size[i];
					res->child[i] = self(self, q->child[i], res);
				}
				return res;
			};
			root = dfs(dfs, rhs.root, nullptr);
			n = rhs.n;
			e_ptr[0] = e_ptr[1] = root;
			if (root) for (int i = 0; i < 2; ++i) while (e_ptr[i]->child[i]) e_ptr[i] = e_ptr[i]->child[i];
		}
		return *this;
	}
	
	AVL_Tree & operator =(AVL_Tree && rhs) {
		if (this != &rhs) {
			clear();
			n = rhs.n;
			rhs.n = 0;
			root = rhs.root;
			rhs.root = nullptr;
			std::copy(rhs.e_ptr, rhs.e_ptr + 2, e_ptr);
			std::fill(rhs.e_ptr, rhs.e_ptr + 2, nullptr);
		}
		return *this;
	}
	
	bool empty() const noexcept {
		return size() == 0;
	}
	
	size_type size() const noexcept {
		return n;
	}
	
	void clear() {
		if (!root) return;
		std::stack<node_ptr> stk;
		stk.emplace(root);
		while (!stk.empty()) {
			node_ptr node = stk.top();
			stk.pop();
			if (node->child[0]) stk.emplace(node->child[0]);
			if (node->child[1]) stk.emplace(node->child[1]);
			delete node;
		}
		n = 0;
		root = nullptr;
		std::fill(e_ptr, e_ptr + 2, nullptr);
	}
	
	std::vector<value_type> enumerate() const {
		std::vector<value_type> elements;
		elements.reserve(size());
		auto dfs = [&elements](auto self, const_ptr q) -> void {
			if (!q) return;
			self(self, q->child[0]);
			elements.emplace_back(q->val);
			self(self, q->child[1]);
		};
		dfs(dfs, root);
		return elements;
	}
	
	node_ptr begin() const noexcept {
		return e_ptr[0];
	}
	
	node_ptr end() const noexcept {
		return nullptr;
	}
	
	node_ptr insert(const_reference x) {
		node_ptr q = root, r = nullptr;
		bool ef[2] {}, d = false;
		while (q) {
			r = q;
			d = q->val <= x;
			q = q->child[d];
			ef[!d] = true;
		}
		q = new Node(x, r, d);
		++n;
		if (!ef[0]) e_ptr[0] = q;
		if (!ef[1]) e_ptr[1] = q;
		if (r) {
			r->size[d] = 1;
			r->height[d] = 1;
			r->child[d] = q;
			update(r);
		}
		else root = q;
		return q;
	}
	
	node_ptr erase(const_reference x) noexcept {
		node_ptr q = find(x);
		if (q == end()) return end();
		return erase(q);
	}
	
	node_ptr erase(node_ptr q) noexcept {
		if (!q) return end();
		const node_ptr ret = next(q);
		if (q->child[0] && q->child[1]) {
			node_ptr p = q->child[0];
			while (p->child[1]) p = p->child[1];
			q->val = std::move(p->val);
			q = p;
		}
		if (e_ptr[0] == q) e_ptr[0] = next(q);
		if (e_ptr[1] == q) e_ptr[1] = prev(q);
		const node_ptr r = q->par;
		if (q->child[0] || q->child[1]) {
			const node_ptr p = q->child[0] ? q->child[0] : q->child[1];
			if (r) {
				r->size[q->isr] = q->size[p->isr];
				r->height[q->isr] = q->height[p->isr];
				r->child[q->isr] = p;
				p->par = r;
				p->isr = q->isr;
			}
			else {
				p->par = nullptr;
				root = p;
			}
		}
		else if (r) {
			r->size[q->isr] = 0;
			r->height[q->isr] = 0;
			r->child[q->isr] = nullptr;
		}
		else root = nullptr;
		delete q;
		--n;
		if (r) update(r);
		return ret;
	}
	
	node_ptr find(const_reference x) const noexcept {
		const node_ptr q = lower_bound(x);
		if (q != end() && q->val != x) return end();
		return q;
	}
	
	node_ptr lower_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val < x]) q = q->child[q->val < x];
		if (q->val < x) q = next(q);
		return q;
	}
	
	node_ptr upper_bound(const_reference x) const noexcept {
		node_ptr q = root;
		if (!q) return end();
		while (q->child[q->val <= x]) q = q->child[q->val <= x];
		if (q->val <= x) q = next(q);
		return q;
	}
	
	size_type count_less_than(const_reference x) const noexcept {
		size_type res = 0;
		node_ptr q = root;
		while (q != nullptr) {
			bool r = q->val < x;
			if (r) res += q->size[0] + 1;
			q = q->child[r];
		}
		return res;
	}
	
	size_type count_less_equal(const_reference x) const noexcept {
		size_type res = 0;
		node_ptr q = root;
		while (q != nullptr) {
			bool r = q->val <= x;
			if (r) res += q->size[0] + 1;
			q = q->child[r];
		}
		return res;
	}
	
	size_type count_greater_than(const_reference x) const noexcept {
		return size() - count_less_equal(x);
	}
	
	size_type count_greater_equal(const_reference x) const noexcept {
		return size() - count_less_than(x);
	}
	
	size_type count(const_reference x) const noexcept {
		return count_less_equal(x) - count_less_than(x);
	}
	
	node_ptr k_th_smallest(uint32 k) const noexcept {
		if (k == 0 || n < k) return end();
		node_ptr q = root;
		while (k != q->size[0] + 1) {
			if (k > q->size[0] + 1) k -= q->size[0] + 1, q = q->child[1];
			else q = q->child[0];
		}
		return q;
	}
	
	node_ptr k_th_largest(uint32 k) const noexcept {
		if (k == 0 || n < k) return end();
		return k_th_smallest(n - k + 1);
	}
	
	node_ptr next(node_ptr q) const noexcept {
		return move(q, true);
	}
	
	node_ptr prev(node_ptr q) const noexcept {
		return move(q, false);
	}
	
private:
	node_ptr rotate(node_ptr q, bool d) noexcept {
		node_ptr r = q->par, p = q->child[!d], b = p->child[d];
		(r ? r->child[q->isr] : root) = p;
		q->child[!d] = b;
		p->child[d] = q;
		if (b) {
			b->par = q;
			b->isr = !d;
		}
		p->par = r;
		p->isr = q->isr;
		q->par = p;
		q->isr = d;
		q->size[!d] = p->size[d];
		q->height[!d] = p->height[d];
		p->size[d] = q->size[0] + q->size[1] + 1;
		p->height[d] = std::max(q->height[0], q->height[1]) + 1;
		return p;
	}
	
	void update(node_ptr q) noexcept {
		bool done = false;
		while (true) {
			if (!done && std::abs(q->height[0] - q->height[1]) > 1) {
				const bool d = q->height[0] > q->height[1];
				const node_ptr p = q->child[!d];
				if (p->height[!d] < p->height[d]) rotate(p, !d);
				q = rotate(q, d);
				done = true;
			}
			const node_ptr r = q->par;
			if (!r) break;
			r->size[q->isr] = q->size[0] + q->size[1] + 1;
			r->height[q->isr] = std::max(q->height[0], q->height[1]) + 1;
			q = r;
		}
	}
	
	node_ptr move(node_ptr q, bool d) const noexcept {
		if (q == end()) return e_ptr[!d];
		if (q == begin() && !d) return end();
		if (q->child[d]) for (q = q->child[d]; q->child[!d]; q = q->child[!d]);
		else {
			while (q && (d ^ !q->isr)) q = q->par;
			if (q) q = q->par;
		}
		return q;
	}
};


#line 4 "Test/AVL_Tree.2.test.cpp"

#include <cstdio>
#include <cassert>
#line 8 "Test/AVL_Tree.2.test.cpp"
#include <set>
#include <numeric>

int main() {
	int Q;
	scanf("%d", &Q);
	
	using ll = long long;
	using pll = std::pair<ll, ll>;
	
	AVL_Tree<pll> avl;
	
	while (Q--) {
		int q;
		ll k, v;
		scanf("%d %lld", &q, &k);
		auto it = avl.lower_bound({k, 0});
		
		if (q == 0) {
			scanf("%lld", &v);
			if (it != avl.end() && it->val.first == k) avl.erase(it);
			avl.insert({k, v});
		}
		else {
			printf("%lld\n", it == avl.end() || it->val.first != k ? 0 : it->val.second);
		}
	}
	
	// [begin, next] test

	auto elem = avl.enumerate();
	{
		std::vector<pll> elem2;
		for (auto it = avl.begin(); it != avl.end(); it = avl.next(it)) elem2.emplace_back(it->val);
		assert(elem == elem2);
	}
	// [end, prev] test

	{
		std::vector<pll> elem2;
		for (auto it = avl.prev(avl.end()); it != avl.end(); it = avl.prev(it)) elem2.emplace_back(it->val);
		std::reverse(begin(elem2), end(elem2));
		assert(elem == elem2);
	}
	
	// [*_bound, count*, k_th_*] test

	for (int i = 0; i < elem.size(); ++i) {
		const auto & e = elem[i];
		auto itelemup = std::upper_bound(begin(elem), end(elem), e);
		auto itelemlo = std::lower_bound(begin(elem), end(elem), e);
		assert(avl.lower_bound(e) != avl.end());
		assert(avl.lower_bound(e)->val == e);
		assert(itelemup == elem.end() ? avl.upper_bound(e) == avl.end() : avl.upper_bound(e)->val == *itelemup);
		
		const auto lt = itelemlo - elem.begin(), eq = itelemup - itelemlo, gt = elem.end() - itelemup;
		assert(avl.count_less_than(e) == lt);
		assert(avl.count_less_equal(e) == lt + eq);
		assert(avl.count_greater_than(e) == gt);
		assert(avl.count_greater_equal(e) == gt + eq);
		assert(avl.count(e) == eq);
		assert(avl.k_th_smallest(i + 1)->val == e);
		assert(avl.k_th_largest(elem.size() - i)->val == e);
	}
	
	// [erase, size] test

	std::set<pll> ss;
	for (const auto & e : elem) ss.insert(e);
	std::vector<int> ord(elem.size());
	std::iota(begin(ord), end(ord), 0);
	std::sort(begin(ord), end(ord), [&](int i, int j) {
		if (elem[i].second == elem[j].second) return elem[i].first > elem[j].first;
		return elem[i].second < elem[j].second;
	});
	assert(ss.size() == avl.size());
	for (int i : ord) {
		auto it = avl.erase(elem[i]);
		auto its = ss.find(elem[i]); ++its;
		assert(its == ss.end() ? it == avl.end() : it->val == *its);
		ss.erase(--its);
		assert(ss.size() == avl.size());
	}
}
Back to top page