You have n friends and you want to take m pictures of them. Exactly two of your friends should appear in each picture and no two pictures should contain the same pair of your friends. So if you have n = 3 friends you can take 3 different pictures, each containing a pair of your friends.
Each of your friends has an attractiveness level, specified by the integer $a_i$ for the $i$-th friend. The attractiveness of a picture containing the $i$-th and the $j$-th friends is equal to the exclusive-or (xor operation) of the integers $a_i$ and $a_j$.
You want to take the pictures so that the total attractiveness of your pictures is maximized. You have to calculate this maximum value. Since the result may not fit in a 32-bit integer, print it modulo 1000000007 ($10^9 + 7$).
Input
The first line of input contains two integers n and m — the number of friends and the number of pictures that you want to take.
The next line contains n space-separated integers $a_1, a_2, \ldots, a_n$ ($0 \le a_i \le 10^9$) — the attractiveness values of the friends.
Output
The only line of output should contain an integer — the optimal total sum of attractiveness of your pictures.
Examples
input
3 1
1 2 3
output
3
input
3 2
1 2 3
output
5
input
3 3
1 2 3
output
6
Solution:
#include <iostream>
#include <stdlib.h>
#include <iomanip>
#include <stdio.h>
#include <set>
#include <vector>
#include <map>
#include <cmath>
#include <algorithm>
#include <memory.h>
#include <string>
#include <sstream>

using namespace std;

// Modulus for the printed answer.
const int md = 1000000007;

// Friends' attractiveness values. main() sorts them, so any group of values
// that agree on every bit above some position k is a contiguous index range,
// with the values whose bit k is clear coming first.
int a[55555];

// s[k][i] = how many of a[0..i-1] have bit k set (per-bit prefix counts,
// filled by main() for k = 0..29). Read by pairwise().
int s[33][55555];

// Memo for count(), keyed by its full argument tuple (l, r, ll, rr, k, th).
map < vector <int>, int > mp;

// Returns the first index in [from, to] whose value has `bit` set, or to + 1
// if there is none. On a sorted range whose values agree on all higher bits
// this is the split point between the bit-clear and bit-set parts.
static int firstWithBit(int from, int to, int bit) {
  int i = from;
  while (i <= to && !(a[i] & bit)) i++;
  return i;
}

// Counts pairs whose xor, looking only at bits 0..k, is >= th.
//
// Two query shapes are used:
//   * l == ll && r == rr : unordered pairs {i, j} with l <= i < j <= r;
//   * otherwise          : pairs (i, j) with i in [l, r], j in [ll, rr]; the
//                          recursion only builds this shape from disjoint ranges.
// Each range is a contiguous block of the sorted array whose values agree on
// all bits above k, and th < 2^(k+1). Results are memoized in mp.
int count(int l, int r, int ll, int rr, int k, int th) {
  if (l > r || ll > rr) return 0;
  vector <int> u(6);
  u[0] = l; u[1] = r; u[2] = ll; u[3] = rr; u[4] = k; u[5] = th;
  map < vector <int>, int >::iterator it = mp.find(u);
  if (it != mp.end()) return it->second;
  // std::map references stay valid across the insertions made by recursion.
  int &res = mp[u];
  res = 0;
  const bool same = (l == ll && r == rr);
  if (th == 0) {
    // Every pair qualifies. For a single range that is C(len, 2) unordered
    // pairs; len * len would also count (i, i) and both orders, and it
    // overflows int once len exceeds 46340.
    if (same) {
      long long len = r - l + 1;
      return res = (int)(len * (len - 1) / 2);
    }
    return res = (r - l + 1) * (rr - ll + 1);
  }
  // No bits left to compare but th > 0 remains: no pair can reach it.
  if (k < 0) return res = 0;
  const int bit = 1 << k;
  if (same) {
    const int i = firstWithBit(l, r, bit);  // [l, i) bit k clear, [i, r] set
    if (th & bit) {
      // Only pairs across the split have xor bit k = 1, as th requires.
      return res = count(l, i - 1, i, r, k - 1, th - bit);
    }
    // Pairs across the split already exceed th at bit k; pairs inside one
    // half tie at bit k and are decided by the lower bits.
    return res = (i - l) * (r - i + 1)
               + count(l, i - 1, l, i - 1, k - 1, th)
               + count(i, r, i, r, k - 1, th);
  }
  const int x = firstWithBit(l, r, bit);
  const int y = firstWithBit(ll, rr, bit);
  if (th & bit) {
    // Need xor bit k = 1: pair the bit-clear part of one side with the
    // bit-set part of the other.
    return res = count(l, x - 1, y, rr, k - 1, th - bit)
               + count(x, r, ll, y - 1, k - 1, th - bit);
  }
  // xor bit k = 1 beats th outright; xor bit k = 0 recurses on lower bits.
  return res = (x - l) * (rr - y + 1) + (r - x + 1) * (y - ll)
             + count(l, x - 1, ll, y - 1, k - 1, th)
             + count(x, r, y, rr, k - 1, th);
}

// Returns, modulo md, the sum of (a[i] xor a[j]) restricted to bits 0..k over
// all pairs i in [l, r], j in [ll, rr] (no threshold filter), using the
// per-bit prefix counts in s.
int pairwise(int k, int l, int r, int ll, int rr) {
  if (l > r || ll > rr) return 0;
  const long long lenL = r - l + 1;
  const long long lenR = rr - ll + 1;
  long long total = 0;
  for (; k >= 0; k--) {
    const long long onesL = s[k][r + 1] - s[k][l];
    const long long onesR = s[k][rr + 1] - s[k][ll];
    // Pairs that differ in bit k each contribute 2^k.
    const long long differ = onesL * (lenR - onesR) + (lenL - onesL) * onesR;
    total = (total + differ % md * (1LL << k)) % md;
  }
  return (int)total;
}

// Returns, modulo md, the sum of (a[i] xor a[j]) restricted to bits 0..k over
// the pairs of the given shape (see count()) whose restricted xor is >= th.
int sum(int l, int r, int ll, int rr, int k, int th) {
  if (l > r || ll > rr) return 0;
  if (k < 0) return 0;
  const int bit = 1 << k;
  if (l == ll && r == rr) {
    const int i = firstWithBit(l, r, bit);
    if (th & bit) {
      // Every qualifying pair crosses the split: each contributes 2^k for
      // bit k plus its lower-bit xor.
      return ((long long)count(l, i - 1, i, r, k - 1, th - bit) * bit
              + sum(l, i - 1, i, r, k - 1, th - bit)) % md;
    }
    // Every cross pair qualifies (pairwise sums bits 0..k of all of them);
    // pairs inside one half are filtered by the lower bits.
    return ((long long)pairwise(k, l, i - 1, i, r)
            + sum(l, i - 1, l, i - 1, k - 1, th)
            + sum(i, r, i, r, k - 1, th)) % md;
  }
  const int x = firstWithBit(l, r, bit);
  const int y = firstWithBit(ll, rr, bit);
  if (th & bit) {
    const long long cnt = (long long)count(l, x - 1, y, rr, k - 1, th - bit)
                        + count(x, r, ll, y - 1, k - 1, th - bit);
    return (cnt * bit
            + sum(l, x - 1, y, rr, k - 1, th - bit)
            + sum(x, r, ll, y - 1, k - 1, th - bit)) % md;
  }
  return ((long long)pairwise(k, l, x - 1, y, rr)
          + pairwise(k, x, r, ll, y - 1)
          + sum(l, x - 1, ll, y - 1, k - 1, th)
          + sum(x, r, y, rr, k - 1, th)) % md;
}

// Reads n, m and the values; prints the maximum total xor of m distinct pairs
// modulo md.
int main() {
  int n, m, i;
  cin >> n >> m;
  for (i = 0; i < n; i++) scanf("%d", a + i);
  sort(a, a + n);
  for (int k = 0; k < 30; k++) {
    s[k][0] = 0;
    for (i = 0; i < n; i++) s[k][i + 1] = s[k][i] + (!!(a[i] & (1 << k)));
  }
  // Binary search the largest threshold T with at least m pairs of xor >= T.
  // T = 0 always qualifies (that is every pair); values are < 2^30, so xor is.
  int lo = 0, hi = (1 << 30) - 1;
  while (lo < hi) {
    int mid = (lo + hi + 1) >> 1;
    // Memo entries from other thresholds are not reused in a useful way;
    // drop them so the memo does not accumulate across all probes.
    mp.clear();
    if (count(0, n - 1, 0, n - 1, 29, mid) >= m) lo = mid;
    else hi = mid - 1;
  }
  mp.clear();
  // Take every pair with xor >= T. Fewer than m pairs have xor > T, so the
  // surplus count(T) - m pairs all have xor exactly T; subtract them.
  int ans = sum(0, n - 1, 0, n - 1, 29, lo);
  ans = (ans - (long long)(count(0, n - 1, 0, n - 1, 29, lo) - m) * lo) % md;
  if (ans < 0) ans += md;
  printf("%d\n", ans);
  return 0;
}