tkmst201's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub tkmst201/Library

✔ Test/VertexUpdatePathFold.vertex.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"

#include "Mathematics/ModInt.hpp"
#include "GraphTheory/VertexUpdatePathFold.hpp"

#include <cstdio>
#include <vector>
#include <utility>

int main() {
	int N, Q;
	scanf("%d %d", &N, &Q);
	
	using mint = ModInt<998244353>;
	using P = std::pair<mint, mint>;
	using VUPF = VertexUpdatePathFold<P>;
	
	std::vector<P> init(N);
	for (int i = 0; i < N; ++i) {
		int a, b;
		scanf("%d %d", &a, &b);
		init[i] = {a, b};
	}
	
	VUPF::Graph g(N);
	for (int i = 0; i < N - 1; ++i) {
		int u, v;
		scanf("%d %d", &u, &v);
		g[u].emplace_back(v);
		g[v].emplace_back(u);
	}
	
	VUPF vupf(g, 0, init, true, {1, 0}, [](auto && x, auto && y) -> P { return {x.first * y.first, x.second * y.first + y.second}; });
	while (Q--) {
		int q;
		scanf("%d", &q);
		if (q == 0) {
			int p, c, d;
			scanf("%d %d %d", &p, &c, &d);
			vupf.set(p, {c, d});
		}
		else {
			int u, v, x;
			scanf("%d %d %d", &u, &v, &x);
			auto res = vupf.fold(u, v);
			printf("%d\n", (res.first * x + res.second).val());
		}
	}
}
#line 1 "Test/VertexUpdatePathFold.vertex.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/vertex_set_path_composite"

#line 1 "Mathematics/ModInt.hpp"



#include <cassert>
#include <iostream>
#include <cstdint>

/**
 * @brief https://tkmst201.github.io/Library/Mathematics/ModInt.hpp
 *
 * Integer modulo the compile-time constant M.
 * The stored representative is always kept in the range [0, M).
 */
template<int M>
struct ModInt {
	static_assert(M > 0);
	
	using value_type = int;
	using calc_type = std::int_fast64_t;
	
private:
	value_type val_; // canonical representative, 0 <= val_ < M
	
public:
	// Reduces any signed value into [0, M).
	// `val % M` lies in (-M, M); adding M and reducing once more also maps
	// negative multiples of M to 0 (a plain "+ M for negatives" would yield M).
	constexpr ModInt(calc_type val = 0) : val_(static_cast<value_type>((val % M + M) % M)) {}
	// Returns the representative in [0, M).
	constexpr value_type val() const noexcept { return val_; }
	// Returns the modulus M.
	constexpr static decltype(M) mod() noexcept { return M; }
	
	// True iff the value is non-zero.
	explicit constexpr operator bool() const noexcept { return val_; }
	constexpr bool operator !() const noexcept { return !static_cast<bool>(*this); }
	constexpr ModInt operator +() const noexcept { return *this; }
	constexpr ModInt operator -() const noexcept { return ModInt(val_ == 0 ? 0 : M - val_); }
	constexpr ModInt operator ++(int) noexcept { ModInt res = *this; ++*this; return res; }
	constexpr ModInt operator --(int) noexcept { ModInt res = *this; --*this; return res; }
	constexpr ModInt & operator ++() noexcept { val_ = val_ + 1 == M ? 0 : val_ + 1; return *this; }
	constexpr ModInt & operator --() noexcept { val_ = val_ == 0 ? M - 1 : val_ - 1; return *this; }
	// Addition/subtraction pick the branch that keeps the result inside [0, M)
	// without ever forming a value outside the int range.
	constexpr ModInt & operator +=(const ModInt & rhs) noexcept { val_ += val_ < M - rhs.val_ ? rhs.val_ : rhs.val_ - M; return *this; }
	constexpr ModInt & operator -=(const ModInt & rhs) noexcept { val_ += val_ >= rhs.val_ ? -rhs.val_ : M - rhs.val_; return *this; }
	// The product is formed in calc_type so that it cannot overflow int.
	constexpr ModInt & operator *=(const ModInt & rhs) noexcept { val_ = static_cast<calc_type>(val_) * rhs.val_ % M; return *this; }
	// Division multiplies by rhs.inv(); see inv() for its precondition.
	constexpr ModInt & operator /=(const ModInt & rhs) noexcept { return *this *= rhs.inv(); }
	friend constexpr ModInt operator +(const ModInt & lhs, const ModInt & rhs) noexcept { return ModInt(lhs) += rhs; }
	friend constexpr ModInt operator -(const ModInt & lhs, const ModInt & rhs) noexcept { return ModInt(lhs) -= rhs; }
	friend constexpr ModInt operator *(const ModInt & lhs, const ModInt & rhs) noexcept { return ModInt(lhs) *= rhs; }
	friend constexpr ModInt operator /(const ModInt & lhs, const ModInt & rhs) noexcept { return ModInt(lhs) /= rhs; }
	friend constexpr bool operator ==(const ModInt & lhs, const ModInt & rhs) noexcept { return lhs.val_ == rhs.val_; }
	friend constexpr bool operator !=(const ModInt & lhs, const ModInt & rhs) noexcept { return !(lhs == rhs); }
	friend std::ostream & operator <<(std::ostream & os, const ModInt & rhs) { return os << rhs.val_; }
	// Reads a calc_type and reduces it through the constructor.
	friend std::istream & operator >>(std::istream & is, ModInt & rhs) { calc_type x; is >> x; rhs = ModInt(x); return is; }
	
	// Binary exponentiation. A negative exponent raises the inverse to -n.
	constexpr ModInt pow(calc_type n) const noexcept {
		ModInt res = 1, x = val_;
		if (n < 0) { x = x.inv(); n = -n; }
		while (n) { if (n & 1) res *= x; x *= x; n >>= 1; }
		return res;
	}
	
	// Multiplicative inverse via the extended Euclidean algorithm.
	// Asserts that gcd(val, M) == 1, i.e. that the inverse exists.
	constexpr ModInt inv() const noexcept {
		value_type a = val_, a1 = 1, b = M, b1 = 0;
		while (b > 0) {
			const value_type q = a / b;
			value_type tmp = a - q * b; a = b; b = tmp;
			tmp = a1 - q * b1; a1 = b1; b1 = tmp;
		}
		assert(a == 1);
		if (a1 < 0) a1 += M;
		return a1;
	}
};


#line 1 "GraphTheory/VertexUpdatePathFold.hpp"



#include <algorithm>
#line 6 "GraphTheory/VertexUpdatePathFold.hpp"
#include <functional>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

/**
 * @brief https://tkmst201.github.io/Library/GraphTheory/VertexUpdatePathFold.hpp
 *
 * Point update / path fold on a rooted tree.
 *
 * The tree is split into heavy paths; every heavy path is stored as a binary
 * tree whose leaves are the path's vertices ordered from the shallowest to the
 * deepest one. Each node keeps the fold of its leaves in both directions, so a
 * path query can be answered for a non-commutative operation `f`.
 *
 * With VERTEX == true values live on vertices (set(v, x) / get(v)).
 * With VERTEX == false values live on edges (set(u, v, x) / get(u, v)); an edge
 * value is stored at whichever endpoint is the child of the other.
 *
 * `f` is combined assuming `id_elem` is its identity element.
 */
template<typename T>
struct VertexUpdatePathFold {
	using value_type = T;
	using const_reference = const value_type &;
	using Graph = std::vector<std::vector<int>>;
	using F = std::function<value_type (const_reference, const_reference)>;
	
private:
	using int3 = std::tuple<int, int, int>;
	// One node of the per-heavy-path binary trees.
	// Indices [0, n) are the tree vertices (leaves); indices >= n are internal nodes.
	struct Node {
		value_type val, rval; // fold of the covered leaves shallow->deep / deep->shallow
		int par;              // parent node; for the top node of a path: the vertex above the path head
		                      // (during build() it temporarily holds the adjacency index of the heavy child)
		int dep; // 16bits(seg depth) 16bits(heavy-path depth)

		int childs[2];
		Node() = default;
		Node(value_type val) : val(val), rval(val), par(-1), dep(0), childs{-1, -1} {}
		Node(int par, int seg_dep, int heavy_dep) : par(par), dep(comp_dep(seg_dep, heavy_dep)) {}
		Node(int par, int seg_dep, int heavy_dep, int lcld, int rcld)
			: par(par), dep(comp_dep(seg_dep, heavy_dep)), childs{lcld, rcld} {}
		// Packs both depths into one int: seg depth in the high bits, heavy-path depth in the low 16 bits.
		static int comp_dep(int seg_dep, int heavy_dep) noexcept {
			return seg_dep << 16 | heavy_dep;
		}
		// Returns {seg depth, heavy-path depth}.
		std::pair<int, int> decomp_dep() const noexcept {
			return {dep >> 16, dep & ((1 << 16) - 1)};
		}
		void set_dep(int seg_dep, int heavy_dep) noexcept {
			dep = comp_dep(seg_dep, heavy_dep);
		}
		void set_prop(int par, int seg_dep, int heavy_dep) noexcept {
			this->par = par;
			set_dep(seg_dep, heavy_dep);
		}
	};
	
	int n;                 // number of tree vertices
	const bool VERTEX;     // true: values on vertices, false: values on edges
	value_type id_elem;    // identity element used to seed folds
	F f;                   // binary operation
	std::vector<Node> nodes;
	std::vector<int> par_, depth_, heavy; // tree parent, tree depth, heavy-path id of each vertex
	
	// Recomputes both folds of internal node u from its two children.
	void node_calc(int u) {
		nodes[u].val = f(nodes[nodes[u].childs[0]].val, nodes[nodes[u].childs[1]].val);
		nodes[u].rval = f(nodes[nodes[u].childs[1]].rval, nodes[nodes[u].childs[0]].rval);
	}
	
public:
	// Builds the structure on tree `g` rooted at `root`; every value starts as `id_elem`.
	// The parameter must be spelled `id_elem` here: the member initializer below reads
	// the parameter, otherwise the member would be initialized from itself.
	VertexUpdatePathFold(const Graph &g, int root, bool VERTEX, const_reference id_elem, const F &f)
		: n(g.size()), VERTEX(VERTEX), id_elem(id_elem), f(f) {
		nodes.reserve(2 * n);
		nodes.assign(n, {id_elem});
		build(g, root);
	}
	
	// Same as above, but vertex i starts with dat[i]
	// (in edge mode dat[i] is the value stored at vertex i, see the class comment).
	template<typename U>
	VertexUpdatePathFold(const Graph &g, int root, const std::vector<U> &dat, bool VERTEX, const_reference id_elem, const F &f)
		: n(g.size()), VERTEX(VERTEX), id_elem(id_elem), f(f) {
		nodes.reserve(2 * n);
		for (int i = 0; i < n; ++i) nodes.emplace_back(dat[i]);
		build(g, root);
	}
	
private:
	// Computes par_/depth_/subtree sizes, then decomposes the tree into heavy paths and
	// builds one binary tree per heavy path. Both traversals are iterative (explicit stacks).
	void build(const Graph &g, int root) {
		par_.assign(n, -1);
		depth_.assign(n, 0);
		heavy.assign(n, -1);
		std::vector<int> siz(n, 0);
		std::stack<std::pair<int, int>> stk; // (vertex, next adjacency index)
		stk.emplace(root, 0);
		siz[root] = 0;
		// Pass 1: parents, depths and subtree sizes.
		while (stk.size()) {
			auto [u, idx] = stk.top(); stk.pop();
			if (idx == static_cast<int>(g[u].size())) {
				if (par_[u] != -1) siz[par_[u]] += ++siz[u];
			}
			else {
				stk.emplace(u, idx + 1);
				const int v = g[u][idx];
				if (v == par_[u]) continue;
				stk.emplace(v, 0);
				par_[v] = u;
				depth_[v] = depth_[u] + 1;
			}
		}
		// Pass 2: heavy-path decomposition.
		// heavy_stk holds the vertices of the paths currently being walked;
		// the head of each path is stored as -(v + 1) so it can be recognised.
		// nodes[v].dep temporarily stores the plain heavy-path depth of v, and
		// siz[] is reused as the weight of a vertex when balancing its path's tree.
		int heavy_num = 0;
		std::vector<int> heavy_stk;
		heavy_stk.emplace_back(-(root + 1));
		nodes[root].dep = 0;
		stk.emplace(root, 0);
		while (stk.size()) {
			auto [u, idx] = stk.top(); stk.pop();
			if (idx == static_cast<int>(g[u].size())) {
				// All light children of u are done: continue along the heavy child, if any.
				if (nodes[u].par != -1) {
					const int v = g[u][nodes[u].par];
					heavy_stk.emplace_back(v);
					nodes[v].dep = nodes[u].dep;
					stk.emplace(v, 0);
					continue;
				}
				// u is the last vertex of its heavy path: build the path's binary tree.
				if (g[u].size() == 0) siz[u] = 1;
				int st = static_cast<int>(heavy_stk.size()) - 1;
				while (heavy_stk[st] >= 0) --st;
				heavy_stk[st] = -heavy_stk[st] - 1; // decode the path head
				const int rootnum = n << 1; // sentinel "parent" of the path containing the root
				const int heavy_par = st == 0 ? rootnum : (heavy_stk[st - 1] >= 0 ? heavy_stk[st - 1] : -heavy_stk[st - 1] - 1);
				const int hs = static_cast<int>(heavy_stk.size()) - st; // path length
				const int heavy_dep = nodes[heavy_stk[st]].dep;
				for (int i = 0; i < hs; ++i) heavy[heavy_stk[st + i]] = heavy_num;
				++heavy_num;
				// Prefix sums of the vertex weights, used to choose split points.
				std::vector<int> sum(hs + 1);
				sum[0] = 0;
				for (int i = 0; i < hs; ++i) sum[i + 1] = sum[i] + siz[heavy_stk[st + i]];
				// Work items (l, r, p): build leaves [l, r) below node p.
				// p is encoded as -(p + 1) when the result becomes p's right child;
				// (-1, -1, p) means "children of p are complete, run node_calc(p)".
				std::stack<int3> merge_stk;
				merge_stk.emplace(0, hs, heavy_par);
				while (merge_stk.size()) {
					auto [l, r, p] = merge_stk.top(); merge_stk.pop();
					if (l == -1) { node_calc(p); continue; }
					const bool rchild = p < 0;
					if (p < 0) p = -p - 1;
					const int seg_dep = (p == heavy_par) ? 0 : nodes[p].decomp_dep().first + 1;
					// Creates an internal node below p with children u and v (-1: filled in later).
					auto merge = [&](int u, int v, int p, bool rchild, int seg_dep) -> int {
						const int self = nodes.size();
						nodes.emplace_back(p, seg_dep, heavy_dep, u, v);
						if (u != -1) nodes[u].set_prop(self, seg_dep + 1, heavy_dep);
						if (v != -1) nodes[v].set_prop(self, seg_dep + 1, heavy_dep);
						if (p != heavy_par) nodes[p].childs[rchild] = self;
						if (u != -1 && v != -1) node_calc(self);
						else merge_stk.emplace(-1, -1, self); // evaluated once the missing children exist
						return self;
					};
					if (r - l <= 2) {
						if (r - l == 1) {
							const int v = heavy_stk[l + st];
							nodes[v].set_prop(p, seg_dep, heavy_dep);
							if (p != heavy_par) nodes[p].childs[rchild] = v;
						}
						else merge(heavy_stk[l + st], heavy_stk[l + st + 1], p, rchild, seg_dep);
						continue;
					}
					// m is the leaf at which the prefix weight reaches half of the range's weight.
					const int m = std::lower_bound(sum.begin() + l + 1, sum.begin() + r + 1, (sum[r] + sum[l]) >> 1) - sum.begin() - 1;
					const int v = heavy_stk[m + st];
					const int top = nodes.size();
					if (m == l) {
						merge(v, -1, p, rchild, seg_dep);
						merge_stk.emplace(m + 1, r, -top - 1);
					}
					else if (m + 1 == r) {
						merge(-1, v, p, rchild, seg_dep);
						merge_stk.emplace(l, m, top);
					}
					else {
						// v is attached to the lighter of the two sides.
						if (sum[m] - sum[l] < sum[r] - sum[m + 1]) {
							merge(-1, -1, p, rchild, seg_dep);
							merge_stk.emplace(m + 1, r, -top - 1);
							merge(-1, v, top, false, seg_dep + 1);
							merge_stk.emplace(l, m, top + 1);
						}
						else {
							merge(-1, -1, p, rchild, seg_dep);
							merge_stk.emplace(l, m, top);
							merge(v, -1, top, true, seg_dep + 1);
							merge_stk.emplace(m + 1, r, -(top + 1) - 1);
						}
					}
				}
				while (static_cast<int>(heavy_stk.size()) > st) heavy_stk.pop_back();
				if (heavy_par != rootnum) siz[heavy_par] += hs;
			}
			else {
				if (idx == 0) {
					// First visit: choose the child with the largest subtree as heavy child.
					siz[u] = 1;
					int mxc = 0;
					for (int i = 0; i < static_cast<int>(g[u].size()); ++i) {
						const int v = g[u][i];
						if (v != par_[u] && mxc < siz[v]) {
							nodes[u].par = i;
							mxc = siz[v];
						}
					}
				}
				stk.emplace(u, idx + 1);
				const int v = g[u][idx];
				if (v == par_[u] || idx == nodes[u].par) continue;
				// Light child: it starts a new heavy path one level deeper.
				heavy_stk.emplace_back(-(v + 1));
				nodes[v].dep = nodes[u].dep + 1;
				stk.emplace(v, 0);
			}
		}
	}

public:
	// Number of vertices.
	int size() const noexcept {
		return n;
	}
	
	// Parent of v in the rooted tree, -1 for the root.
	int par(int v) const noexcept {
		assert(0 <= v && v < n);
		return par_[v];
	}
	
	// Depth of v in the rooted tree (root has depth 0).
	int depth(int v) const noexcept {
		assert(0 <= v && v < n);
		return depth_[v];
	}
	
	// Vertex mode only: assigns x to vertex v.
	void set(int v, const_reference x) noexcept {
		assert(VERTEX);
		assert(0 <= v && v < n);
		set_impl(v, x);
	}
	
	// Vertex mode only: returns the value of vertex v.
	value_type get(int v) const noexcept {
		assert(VERTEX);
		assert(0 <= v && v < n);
		return get_impl(v);
	}
	
	// Edge mode only: assigns x to the edge between u and v.
	// The value is stored at the child endpoint; adjacency of u and v is not checked.
	void set(int u, int v, const_reference x) noexcept {
		assert(!VERTEX);
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		set_impl(par_[u] == v ? u : v, x);
	}
	
	// Edge mode only: returns the value of the edge between u and v.
	value_type get(int u, int v) const noexcept {
		assert(!VERTEX);
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		return get_impl(par_[u] == v ? u : v);
	}
	
	// Folds the values along the path from u to v, in that order.
	// In edge mode the value stored at the shallowest vertex of the path is skipped,
	// so fold(u, u) yields f(id_elem, id_elem).
	value_type fold(int u, int v) const noexcept {
		assert(0 <= u && u < n);
		assert(0 <= v && v < n);
		// lv accumulates the prefix starting at u, rv the suffix ending at v.
		value_type lv = id_elem, rv = id_elem;
		// Moves u upwards. With step == -1: take u and everything above it on its heavy
		// path, ending on the vertex above the path head. Otherwise climb `step` levels of
		// the binary tree, collecting left siblings (lret) or right siblings (!lret).
		auto uup = [&](int step = -1, bool lret = true) {
			if (step == -1) {
				lv = f(lv, nodes[u].rval);
				step = nodes[u].decomp_dep().first + 1;
			}
			while (step--) {
				const int p = nodes[u].par;
				if (lret && nodes[p].childs[1] == u) lv = f(lv, nodes[nodes[p].childs[0]].rval);
				if (!lret && nodes[p].childs[0] == u) lv = f(lv, nodes[nodes[p].childs[1]].val);
				u = p;
			}
		};
		// Mirror image of uup for the v side; new parts are prepended to rv.
		auto vup = [&](int step = -1, bool lret = true) {
			if (step == -1) {
				rv = f(nodes[v].val, rv);
				step = nodes[v].decomp_dep().first + 1;
			}
			while(step--) {
				const int p = nodes[v].par;
				if (lret && nodes[p].childs[1] == v) rv = f(nodes[nodes[p].childs[0]].val, rv);
				if (!lret && nodes[p].childs[0] == v) rv = f(nodes[nodes[p].childs[1]].rval, rv);
				v = p;
			}
		};
		// Bring both endpoints onto the same heavy path.
		while (nodes[u].decomp_dep().second > nodes[v].decomp_dep().second) uup();
		while (nodes[u].decomp_dep().second < nodes[v].decomp_dep().second) vup();
		while (heavy[u] != heavy[v]) uup(), vup();
		// lright: u is the deeper endpoint, i.e. the one further right among the leaves.
		bool lright = depth_[u] > depth_[v];
		if (u == v) return VERTEX ? f(f(lv, nodes[u].val), rv) : f(lv, rv);
		if (VERTEX || lright) lv = f(lv, nodes[u].val);
		if (VERTEX || !lright) rv = f(nodes[v].val, rv);
		// Climb the path's binary tree until both sides meet, collecting the leaves in between.
		const int udep = nodes[u].decomp_dep().first, vdep = nodes[v].decomp_dep().first;
		if (udep > vdep) uup(udep - vdep, lright);
		if (udep < vdep) vup(vdep - udep, !lright);
		while (nodes[u].par != nodes[v].par) uup(1, lright), vup(1, !lright);
		return f(lv, rv);
	}
	
private:
	// Overwrites leaf v and recomputes all of its ancestors inside its path's binary tree.
	void set_impl(int v, const_reference x) noexcept {
		nodes[v].val = x;
		nodes[v].rval = x;
		for (int i = nodes[v].decomp_dep().first; i > 0; --i) node_calc(v = nodes[v].par);
	}
	
	value_type get_impl(int v) const noexcept {
		return nodes[v].val;
	}
};


#line 5 "Test/VertexUpdatePathFold.vertex.test.cpp"

#include <cstdio>
#line 9 "Test/VertexUpdatePathFold.vertex.test.cpp"

int main() {
	int N, Q;
	scanf("%d %d", &N, &Q);
	
	using mint = ModInt<998244353>;
	using P = std::pair<mint, mint>;
	using VUPF = VertexUpdatePathFold<P>;
	
	std::vector<P> init(N);
	for (int i = 0; i < N; ++i) {
		int a, b;
		scanf("%d %d", &a, &b);
		init[i] = {a, b};
	}
	
	VUPF::Graph g(N);
	for (int i = 0; i < N - 1; ++i) {
		int u, v;
		scanf("%d %d", &u, &v);
		g[u].emplace_back(v);
		g[v].emplace_back(u);
	}
	
	VUPF vupf(g, 0, init, true, {1, 0}, [](auto && x, auto && y) -> P { return {x.first * y.first, x.second * y.first + y.second}; });
	while (Q--) {
		int q;
		scanf("%d", &q);
		if (q == 0) {
			int p, c, d;
			scanf("%d %d %d", &p, &c, &d);
			vupf.set(p, {c, d});
		}
		else {
			int u, v, x;
			scanf("%d %d %d", &u, &v, &x);
			auto res = vupf.fold(u, v);
			printf("%d\n", (res.first * x + res.second).val());
		}
	}
}
Back to top page