题目链接:three arrays

​ 已知长度均为 $n$ 的数组 $A, B$,重新排列数组 $A,B$ 后,计算数组 $C$,使得 $C_{i}=A_{i}\ XOR\ B_{i}$,输出字典序最小的 $C$。

​ 对于数组 $A,B$ 建立两棵01字典树,从 $A$ 中随便找个数字 $x_{0}$,在 $B$ 中找最接近 $x_{0}$ 的数字 $x_{1}$,再在 $A$ 中找最接近 $x_{1}$ 的数字 $x_{2}$……这样一直进行下去,会出现大小为 $2$ 的循环节,这时候将这两个数字异或起来,就是我们的一个答案。但是这样我们不能保证答案有序,因此求出每一个答案之后再排序一下输出。

#include<bits/stdc++.h>

using namespace std;
const int maxn = 1e5 + 10;
const int maxm = 2e6 + 10;

struct Trie {
    int ch[maxm][2]{}, num[maxm]{}, val[maxm]{}, cnt = 1;

    void insert(int rt, int dep, int x) {
        if (dep == 30)
            return val[rt] = x, void();
        int t = (x & (1 << (29 - dep))) > 0;
        if (!ch[rt][t])
            ch[rt][t] = ++cnt, num[cnt] = 1;
        else
            num[ch[rt][t]]++;
        insert(ch[rt][t], dep + 1, x);
    }

    int xjbfind(int rt, int dep) {
        if (dep == 30)
            return val[rt];
        if (!ch[rt][0] && !ch[rt][1])
            return -1;
        if (ch[rt][0])
            return xjbfind(ch[rt][0], dep + 1);
        return xjbfind(ch[rt][1], dep + 1);
    }

    int query(int rt, int dep, int x) {
        if (dep == 30)
            return val[rt];
        int t = (x & (1 << (29 - dep))) > 0;
        if (ch[rt][t])
            return query(ch[rt][t], dep + 1, x);
        return query(ch[rt][t ^ 1], dep + 1, x);
    }

    void del(int rt, int dep, int x) {
        if (dep == 30)
            return;
        int t = (x & (1 << (29 - dep))) > 0;
        del(ch[rt][t], dep + 1, x);
        if (!(--num[ch[rt][t]]))
            ch[rt][t] = 0;
    }

    void clear() {
        for (int i = 0; i <= cnt; i++)
            ch[i][0] = ch[i][1] = num[i] = val[i] = 0;
        cnt = 1;
    }
} trie[2];

int a[maxn], b[maxn], ans[maxn], n;

int dfs(int id, int x, int la) {
    int t = trie[id ^ 1].query(1, 0, x);
    if (t == la)
    {
        trie[id].del(1, 0, x);
        trie[id ^ 1].del(1, 0, t);
        return t ^ x;
    }
    return dfs(id ^ 1, t, x);
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n);
        for (int i = 0; i < n; i++)
            scanf("%d", a + i), trie[0].insert(1, 0, a[i]);
        for (int i = 0; i < n; i++)
            scanf("%d", b + i), trie[1].insert(1, 0, b[i]);
        for (int i = 0; i < n; i++)
        {
            int t = trie[0].xjbfind(1, 0);
            ans[i] = dfs(1, t, -1);
        }
        sort(ans, ans + n);
        for (int i = 0; i < n; i++)
            printf("%d%c", ans[i], "\n "[i < n - 1]);
        trie[0].clear(), trie[1].clear();
    }
}