K Paths

You are given a tree with $n$ vertices. You need to select $k$ (not necessarily distinct) simple paths in such a way that the edges of the tree can be split into three sets: edges not contained in any path, edges contained in exactly one of these paths, and edges contained in all selected paths; the last set must be non-empty.

Compute the number of ways to select $k$ paths modulo $998244353$.

The paths are enumerated; in other words, two ways are considered distinct if there exist an index $i$ ($1 \leq i \leq k$) and an edge such that the $i$-th path contains this edge in one way but does not contain it in the other.


The first line contains two integers $n$ and $k$ ($1 \leq n, k \leq 10^{5}$) — the number of vertices in the tree and the desired number of paths.

The next $n - 1$ lines describe the edges of the tree. Each line contains two integers $a$ and $b$ ($1 \le a, b \le n$, $a \ne b$) — the endpoints of an edge. It is guaranteed that the given edges form a tree.


Print the number of ways to select $k$ enumerated not necessarily distinct simple paths in such a way that for each edge either it is not contained in any path, or it is contained in exactly one path, or it is contained in all $k$ paths, and the intersection of all paths is non-empty.

As the answer can be large, print it modulo $998244353$.



3 2
1 2
2 3




5 1
4 1
2 3
4 5
2 1




29 29
1 2
1 3
1 4
1 5
5 6
5 7
5 8
8 9
8 10
8 11
11 12
11 13
11 14
14 15
14 16
14 17
17 18
17 19
17 20
20 21
20 22
20 23
23 24
23 25
23 26
26 27
26 28
26 29




In the first example the following ways are valid:

  • $((1,2), (1,2))$,
  • $((1,2), (1,3))$,
  • $((1,3), (1,2))$,
  • $((1,3), (1,3))$,
  • $((1,3), (2,3))$,
  • $((2,3), (1,3))$,
  • $((2,3), (2,3))$.

In the second example $k=1$, so all $n \cdot (n - 1) / 2 = 5 \cdot 4 / 2 = 10$ paths are valid.

In the third example, the answer is $\geq 998244353$, so it was taken modulo $998244353$, don’t forget it!


#include <bits/stdc++.h>

using namespace std;

// Debug-printing helpers used by debug_out(): render values as strings.
//
// The generic overloads are declared up front so that each overload can call
// any other one (e.g. a pair whose first member is a container).
template <typename A, typename B>
string to_string(pair<A, B> p);
template <typename A>
string to_string(A v);

// Wraps a string in double quotes so empty strings and spaces stay visible.
string to_string(string s) {
  return '"' + s + '"';
}

// C strings are printed exactly like std::string.
string to_string(const char* s) {
  return to_string(string(s));
}

string to_string(bool b) {
  return (b ? "true" : "false");
}

// Renders a pair as "(first, second)".
template <typename A, typename B>
string to_string(pair<A, B> p) {
  return "(" + to_string(p.first) + ", " + to_string(p.second) + ")";
}

// Renders any range-for iterable as "{x1, x2, ...}".
template <typename A>
string to_string(A v) {
  bool first = true;
  string res = "{";
  for (const auto &x : v) {
    // Separator goes before every element except the first one.
    if (!first) {
      res += ", ";
    }
    first = false;
    res += to_string(x);
  }
  res += "}";
  return res;
}

// Terminates a debug_out() chain: ends the line and flushes cerr.
void debug_out() { cerr << endl; }

// Prints each argument to cerr, preceded by a space, then recurses on the rest.
template <typename Head, typename... Tail>
void debug_out(Head H, Tail... T) {
  cerr << " " << to_string(H);
  debug_out(T...);
}

// debug(a, b, ...) prints "[a, b, ...]: <values>" to cerr when compiled with
// LOCAL defined; otherwise it expands to the inert expression 42.
#ifdef LOCAL
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#else
#define debug(...) 42
#endif

// Prime modulus for all arithmetic. 998244353 = 119 * 2^23 + 1, which is what
// lets ntt use power-of-two transform lengths.
const int md = 998244353;

// a = (a + b) mod md; expects a and b already reduced into [0, md).
inline void add(int &a, int b) {
  a += b;
  if (a >= md) a -= md;
}

// a = (a - b) mod md; expects a and b already reduced into [0, md).
inline void sub(int &a, int b) {
  a -= b;
  if (a < 0) a += md;
}

// Returns a * b mod md for a, b in [0, md).
inline int mul(int a, int b) {
#if !defined(_WIN32) || defined(_WIN64)
  return (int) ((long long) a * b % md);
#else
  // 32-bit Windows fallback (presumably to avoid a slow 64-bit '%' there):
  // x86 'divl' divides the 64-bit product xh:xl by md, leaving the quotient
  // in eax (d) and the remainder in edx (m). The quotient fits in 32 bits
  // because the product is below md * md.
  unsigned long long x = (long long) a * b;
  unsigned xh = (unsigned) (x >> 32), xl = (unsigned) x, d, m;
  __asm__(
    "divl %4; \n\t"
    : "=a" (d), "=d" (m)
    : "d" (xh), "a" (xl), "r" (md)
  );
  return (int) m;
#endif
}

// Returns a^b mod md by binary exponentiation (1 when b <= 0).
// Expects a in [0, md).
inline int power(int a, long long b) {
  int res = 1;
  while (b > 0) {
    if (b & 1) {
      res = mul(res, a);
    }
    a = mul(a, a);
    b >>= 1;
  }
  return res;
}

// Returns the multiplicative inverse of a modulo md, computed with the extended
// Euclidean algorithm. a may be any int (it is reduced first). Asserts that
// a is invertible, i.e. not a multiple of md.
inline int inv(int a) {
  a %= md;
  if (a < 0) a += md;
  int b = md, u = 0, v = 1;
  while (a) {
    int t = b / a;
    b -= t * a; swap(a, b);
    u -= t * v; swap(u, v);
  }
  assert(b == 1);
  if (u < 0) u += md;
  return u;
}

namespace ntt {
  int base = 1;
  vector<int> roots = {0, 1};
vector<int> rev = {0, 1};
int max_base = -1;
int root = -1;

void init() {
int tmp = md - 1;
max_base = 0;
while (tmp % 2 == 0) {
tmp /= 2;
root = 2;
while (true) {
if (power(root, 1 << max_base) == 1) {
        if (power(root, 1 << (max_base - 1)) != 1) {

  void ensure_base(int nbase) {
    if (max_base == -1) {
    if (nbase <= base) {
    assert(nbase <= max_base);
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); i++) {
      rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
    roots.resize(1 << nbase);
    while (base < nbase) {
      int z = power(root, 1 << (max_base - 1 - base));
      for (int i = 1 << (base - 1); i < (1 << base); i++) {
        roots[i << 1] = roots[i];
        roots[(i << 1) + 1] = mul(roots[i], z);

  void fft(vector<int> &a) {
int n = (int) a.size();
assert((n & (n - 1)) == 0);
int zeros = __builtin_ctz(n);
int shift = base - zeros;
for (int i = 0; i < n; i++) {
      if (i < (rev[i] >> shift)) {
swap(a[i], a[rev[i] >> shift]);
for (int k = 1; k < n; k <<= 1) {
      for (int i = 0; i < n; i += 2 * k) {
        for (int j = 0; j < k; j++) {
          int x = a[i + j];
          int y = mul(a[i + j + k], roots[j + k]);
          a[i + j] = x + y - md;
          if (a[i + j] < 0) a[i + j] += md;
          a[i + j + k] = x - y + md;
          if (a[i + j + k] >= md) a[i + j + k] -= md;

vector<int> multiply(vector<int> a, vector<int> b, int eq = 0) {
int need = (int) (a.size() + b.size() - 1);
int nbase = 0;
while ((1 << nbase) < need) nbase++;
    int sz = 1 << nbase;
    if (eq) b = a; else fft(b);
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; i++) {
      a[i] = mul(mul(a[i], b[i]), inv_sz);
    reverse(a.begin() + 1, a.end());
    return a;

  vector<int> square(vector<int> a) {
return multiply(a, a, 1);

// Polynomials (coefficient vectors) whose product multiply_all() computes.
vector<vector<int>> vect;

// Returns vect[from] * vect[from + 1] * ... * vect[to - 1] (the constant
// polynomial 1 for an empty or inverted range). The range is split in halves,
// so factors of similar degree get multiplied together.
vector<int> multiply_all(int from, int to) {
  if (from >= to) {
    return {1};
  }
  if (from + 1 == to) {
    return vect[from];
  }
  int mid = (from + to) >> 1;
  return ntt::multiply(multiply_all(from, mid), multiply_all(mid, to));
}

// Capacity of every per-vertex / per-size array below (must exceed n).
constexpr int N = 400010;

// Tree and per-vertex results filled by dfs(), indexed by 0-based vertex id.
vector<int> g[N];  // adjacency lists
int sz[N];         // subtree sizes
int pv[N];         // DFS parent (-1 for the root)
int val[N];        // per-vertex weight computed in dfs()
int sum_val[N];    // sum of val over a vertex's subtree

// Memo used by dfs(), keyed by child-subtree size; memo[S] is current only
// while touched[S] equals ITER.
int memo[N];
int touched[N];
int ITER;

vector<int> all;  // NOTE(review): not referenced in the visible code

int n, k;  // number of vertices, number of paths
int ans;   // answer, accumulated modulo md

// Processes the subtree of v, entered from parent pr (-1 for the root).
//
// For a vertex x with its neighbour components of sizes s_1..s_m, the code
// uses  F = sum_{i <= min(k, m)} k(k-1)...(k-i+1) * e_i(s_1..s_m),
// where e_i is the i-th coefficient of prod (1 + s_j * X): choose which i of
// the k enumerated paths continue past x, each into a distinct component and
// ending anywhere in it.
//
// Sets pv[v], sz[v], val[v] (F over v's child components) and sum_val[v]
// (sum of val over v's subtree). Adds to ans:
//   * sum_val[u] * (sum of earlier children's sum_val), for each child u;
//   * F over all of v's components except child c, times sum_val[c].
void dfs(int v, int pr) {
  pv[v] = pr;
  sz[v] = 1;
  sum_val[v] = 0;
  vector<int> children;
  for (int u : g[v]) {
    if (u == pr) {
      continue;
    }
    dfs(u, v);
    children.push_back(u);
    sz[v] += sz[u];
    // Pairs with one vertex in u's subtree and one in an earlier child's.
    add(ans, mul(sum_val[u], sum_val[v]));
    add(sum_val[v], sum_val[u]);
  }
  // vect is global and reused by the recursive calls above, so it is only
  // rebuilt once all children are done.
  vect.clear();
  for (int u : children) {
    vect.push_back({1, sz[u]});
  }
  // res[i] = e_i(child subtree sizes).
  vector<int> res = multiply_all(0, (int) vect.size());
  assert(res.size() == vect.size() + 1);
  int ways = 1;  // k * (k - 1) * ... * (k - i + 1)
  val[v] = 0;
  for (int i = 0; i <= min(k, (int) vect.size()); i++) {
    add(val[v], mul(ways, res[i]));
    ways = mul(ways, k - i);
  }
  add(sum_val[v], val[v]);
  if (pr != -1) {
    // Include the component on the parent's side of v.
    res = ntt::multiply(res, {1, n - sz[v]});
  }
  // New memo generation: memo[S] below depends on this vertex's res.
  ITER++;
  const int deg = (int) res.size() - 1;
  const int limit = min(k, deg);
  for (int it = 0; it < (int) children.size(); it++) {
    int S = sz[children[it]];
    int &vall = memo[S];
    if (touched[S] != ITER) {
      touched[S] = ITER;
      vall = 0;
      // Synthetic division of res by (1 + S X): cur is quotient coefficient
      // i. Only coefficients i <= k are needed.
      int cur = 0;
      ways = 1;
      for (int i = 0; i <= limit; i++) {
        cur = mul(cur, S);
        cur = (res[i] - cur + md) % md;
        add(vall, mul(ways, cur));
        ways = mul(ways, k - i);
      }
      // The remainder is zero only once the division has run through the
      // top coefficient; when k < deg it was stopped early, so no check.
      if (limit == deg) {
        assert(cur == 0);
      }
    }
    add(ans, mul(vall, sum_val[children[it]]));
  }
}

int main() {
  cin >> n >> k;
for (int i = 0; i < n; i++) {
  for (int i = 0; i < n - 1; i++) {
    int x, y;
    cin >> x >> y;
x--; y--;
for (int i = 0; i <= n; i++) {
    memo[i] = -1;
    touched[i] = -1;
  ITER = 0;
  ans = 0;
  if (k == 1) {
    ans = mul(mul(n, n - 1), inv(2));
  } else {
    dfs(0, -1);
  cout << ans << '\n';
  return 0;