You are given a tree of $n$ vertices. You are to select $k$ (not necessarily distinct) simple paths in such a way that it is possible to split all edges of the tree into three sets: edges not contained in any path, edges that are a part of exactly one of these paths, and edges that are parts of all selected paths, and the latter set should be non-empty.
Compute the number of ways to select $k$ paths modulo $998244353$.
The paths are enumerated; in other words, two ways are considered distinct if there exist an index $i$ ($1 \leq i \leq k$) and an edge such that the $i$-th path contains that edge in one way but not in the other.
Input
The first line contains two integers $n$ and $k$ ($1 \leq n, k \leq 10^{5}$) — the number of vertices in the tree and the desired number of paths.
The next $n - 1$ lines describe edges of the tree. Each line contains two integers $a$ and $b$ ($1 \le a, b \le n$, $a \ne b$) — the endpoints of an edge. It is guaranteed that the given edges form a tree.
Output
Print the number of ways to select $k$ enumerated not necessarily distinct simple paths in such a way that for each edge either it is not contained in any path, or it is contained in exactly one path, or it is contained in all $k$ paths, and the intersection of all paths is non-empty.
As the answer can be large, print it modulo $998244353$.
Examples
input
3 2
1 2
2 3
output
7
input
5 1
4 1
2 3
4 5
2 1
output
10
input
29 29
1 2
1 3
1 4
1 5
5 6
5 7
5 8
8 9
8 10
8 11
11 12
11 13
11 14
14 15
14 16
14 17
17 18
17 19
17 20
20 21
20 22
20 23
23 24
23 25
23 26
26 27
26 28
26 29
output
125580756
Note
In the first example the following ways are valid:
- $((1,2), (1,2))$,
- $((1,2), (1,3))$,
- $((1,3), (1,2))$,
- $((1,3), (1,3))$,
- $((1,3), (2,3))$,
- $((2,3), (1,3))$,
- $((2,3), (2,3))$.
In the second example $k=1$, so all $n \cdot (n - 1) / 2 = 5 \cdot 4 / 2 = 10$ paths are valid.
In the third example, the answer is $\geq 998244353$, so it was taken modulo $998244353$, don’t forget it!
Solution:
// Solution: counts ordered k-tuples of simple paths in a tree (answer modulo
// 998244353). k == 1 is answered in closed form as n(n-1)/2. Otherwise every
// vertex is processed exactly once, children before parents, combining
// polynomials of the form prod(1 + s * x) over neighbouring subtree sizes s.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

const int md = 998244353;

// a = (a + b) mod md, for a, b already in [0, md).
inline void add(int &a, int b) {
  a += b;
  if (a >= md) a -= md;
}

// Returns a * b mod md for a, b in [0, md).
inline int mul(int a, int b) {
#if defined(_WIN32) && !defined(_WIN64)
  // 32-bit Windows only: GCC-style x86 inline asm performing the 64/32
  // division with a single divl. Kept inside this branch so that other
  // targets (including non-x86 ones) never have to assemble it.
  unsigned long long x = (long long) a * b;
  unsigned xh = (unsigned) (x >> 32), xl = (unsigned) x, d, m;
  asm("divl %4; \n\t" : "=a"(d), "=d"(m) : "d"(xh), "a"(xl), "r"(md));
  return m;
#else
  return (int) ((long long) a * b % md);
#endif
}

// Returns a^b mod md by binary exponentiation (b >= 0).
inline int power(int a, long long b) {
  int res = 1;
  while (b > 0) {
    if (b & 1) res = mul(res, a);
    a = mul(a, a);
    b >>= 1;
  }
  return res;
}

// Modular inverse of a via the extended Euclidean algorithm; asserts that
// gcd(a, md) == 1.
inline int inv(int a) {
  a %= md;
  if (a < 0) a += md;
  int b = md, u = 0, v = 1;
  while (a) {
    int t = b / a;
    b -= t * a;
    swap(a, b);
    u -= t * v;
    swap(u, v);
  }
  assert(b == 1);
  if (u < 0) u += md;
  return u;
}

// Number-theoretic transform modulo md with lazily grown root/bit-reversal
// tables.
namespace ntt {
int base = 1;
vector<int> roots = {0, 1};
vector<int> rev = {0, 1};
int max_base = -1;  // largest e with 2^e | (md - 1); -1 until init() runs
int root = -1;      // element of multiplicative order exactly 2^max_base

// Finds max_base and the smallest root >= 2 of order exactly 2^max_base.
void init() {
  int tmp = md - 1;
  max_base = 0;
  while (tmp % 2 == 0) {
    tmp /= 2;
    max_base++;
  }
  root = 2;
  while (true) {
    if (power(root, 1 << max_base) == 1) {
      if (power(root, 1 << (max_base - 1)) != 1) {
        break;
      }
    }
    root++;
  }
}

// Extends rev/roots so transforms of length up to 2^nbase are supported.
void ensure_base(int nbase) {
  if (max_base == -1) init();
  if (nbase <= base) return;
  assert(nbase <= max_base);
  rev.resize(1 << nbase);
  for (int i = 0; i < (1 << nbase); i++) {
    rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
  }
  roots.resize(1 << nbase);
  while (base < nbase) {
    int z = power(root, 1 << (max_base - 1 - base));
    for (int i = 1 << (base - 1); i < (1 << base); i++) {
      roots[i << 1] = roots[i];
      roots[(i << 1) + 1] = mul(roots[i], z);
    }
    base++;
  }
}

// In-place forward transform; a.size() must be a power of two.
void fft(vector<int> &a) {
  int n = (int) a.size();
  assert((n & (n - 1)) == 0);
  int zeros = __builtin_ctz(n);
  ensure_base(zeros);
  int shift = base - zeros;
  for (int i = 0; i < n; i++) {
    if (i < (rev[i] >> shift)) swap(a[i], a[rev[i] >> shift]);
  }
  for (int k = 1; k < n; k <<= 1) {
    for (int i = 0; i < n; i += 2 * k) {
      for (int j = 0; j < k; j++) {
        int x = a[i + j];
        int y = mul(a[i + j + k], roots[j + k]);
        a[i + j] = x + y - md;
        if (a[i + j] < 0) a[i + j] += md;
        a[i + j + k] = x - y + md;
        if (a[i + j + k] >= md) a[i + j + k] -= md;
      }
    }
  }
}

// Returns the product of polynomials a and b (coefficients mod md).
// eq != 0 means b is known to equal a, so its transform is reused.
vector<int> multiply(vector<int> a, vector<int> b, int eq = 0) {
  int need = (int) (a.size() + b.size() - 1);
  int nbase = 0;
  while ((1 << nbase) < need) nbase++;
  ensure_base(nbase);
  int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  fft(a);
  if (eq) {
    b = a;
  } else {
    fft(b);
  }
  int inv_sz = inv(sz);
  for (int i = 0; i < sz; i++) a[i] = mul(mul(a[i], b[i]), inv_sz);
  // Inverse transform = forward transform on the index-reversed spectrum.
  reverse(a.begin() + 1, a.end());
  fft(a);
  a.resize(need);
  return a;
}
}  // namespace ntt

// Factors consumed by multiply_all (refilled for every vertex).
vector<vector<int>> vect;

// Divide-and-conquer product of vect[from..to); returns {1} for an empty range.
vector<int> multiply_all(int from, int to) {
  if (from == to) return {1};
  if (from + 1 == to) return vect[from];
  int mid = (from + to) >> 1;
  return ntt::multiply(multiply_all(from, mid), multiply_all(mid, to));
}

const int N = 400010;
vector<int> g[N];  // adjacency lists, 0-indexed vertices
int sz[N];         // subtree size, tree rooted at vertex 0
int pv[N];         // parent in that rooting (-1 for the root)
int sum_val[N];    // val of v plus sum_val of all its children
int memo[N];       // per-vertex cache keyed by child subtree size
int touched[N], ITER;  // memo[S] is valid for the current vertex iff touched[S] == ITER
int n, k;
int ans;

// Processes vertex v; every child of v must already have been processed.
void process_vertex(int v) {
  const int pr = pv[v];
  sz[v] = 1;
  sum_val[v] = 0;
  vector<int> children;
  for (int u : g[v]) {
    if (u == pr) continue;
    children.push_back(u);
    sz[v] += sz[u];
    // Pair this child's sum_val with those of the children seen before it.
    add(ans, mul(sum_val[u], sum_val[v]));
    add(sum_val[v], sum_val[u]);
  }

  // res = prod over children u of (1 + sz[u] * x).
  vect.clear();
  for (int u : children) vect.push_back({1, sz[u]});
  vector<int> res = multiply_all(0, (int) vect.size());
  assert(res.size() == vect.size() + 1);

  // val = sum_i k(k-1)...(k-i+1) * res[i]; terms with i > k vanish.
  int val = 0;
  int ways = 1;
  for (int i = 0; i <= min(k, (int) vect.size()); i++) {
    add(val, mul(ways, res[i]));
    ways = mul(ways, k - i);
  }
  add(sum_val[v], val);

  // Add the part of the tree above v (n - sz[v] vertices) as one more factor.
  if (pr != -1) res = ntt::multiply(res, {1, n - sz[v]});

  ITER++;
  const int top = (int) res.size() - 1;
  const int lim = min(k, top);
  for (int u : children) {
    const int S = sz[u];
    // The quotient res / (1 + S x) depends only on S, so children of equal
    // size share one computation.
    int &vall = memo[S];
    if (touched[S] != ITER) {
      touched[S] = ITER;
      vall = 0;
      int cur = 0;  // i-th coefficient of res / (1 + S x)
      ways = 1;
      for (int i = 0; i <= lim; i++) {
        cur = mul(cur, S);
        cur = (res[i] - cur + md) % md;
        add(vall, mul(ways, cur));
        ways = mul(ways, k - i);
      }
      // Exact division makes the degree-`top` quotient coefficient zero, but
      // that coefficient is reached only when lim == top. For k < top the loop
      // legitimately stops early (falling factorials vanish past k) and `cur`
      // may be nonzero.
      assert(lim < top || cur == 0);
    }
    add(ans, mul(vall, sum_val[u]));
  }
}

// Roots the tree at vertex 0 and processes vertices children-first. Uses an
// explicit stack instead of recursion: a path-shaped tree with n up to 1e5
// vertices would otherwise need recursion that deep.
void solve_tree() {
  vector<int> order;
  order.reserve(n);
  vector<int> pending = {0};
  pv[0] = -1;
  while (!pending.empty()) {
    int v = pending.back();
    pending.pop_back();
    order.push_back(v);
    for (int u : g[v]) {
      if (u != pv[v]) {
        pv[u] = v;
        pending.push_back(u);
      }
    }
  }
  // In `order` every vertex precedes all of its descendants, so the reverse
  // traversal sees each child before its parent.
  for (int i = (int) order.size() - 1; i >= 0; i--) process_vertex(order[i]);
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cin >> n >> k;
  for (int i = 0; i < n; i++) g[i].clear();
  for (int i = 0; i < n - 1; i++) {
    int x, y;
    cin >> x >> y;
    x--;
    y--;
    g[x].push_back(y);
    g[y].push_back(x);
  }
  for (int i = 0; i <= n; i++) {
    memo[i] = -1;
    touched[i] = -1;
  }
  ITER = 0;
  ans = 0;
  if (k == 1) {
    // A single path is valid iff it has at least one edge: C(n, 2) choices.
    ans = mul(mul(n, n - 1), inv(2));
  } else {
    solve_tree();
  }
  cout << ans << '\n';
  return 0;
}