// This one felt like a tricky problem! (原文: 感觉很麻烦的一道题!)
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
/// Binary tree node holding an int key and two child links.
/// A node does not own or free its children; whoever allocates the tree
/// is responsible for releasing it.
struct TreeNode {
    int val{0};
    TreeNode* left{nullptr};
    TreeNode* right{nullptr};

    TreeNode() = default;
    // explicit: an int should not silently convert into a tree node.
    explicit TreeNode(int _val) : val(_val) {}
    TreeNode(int _val, TreeNode* _left, TreeNode* _right)
        : val(_val), left(_left), right(_right) {}
    // No user-declared destructor (Rule of Zero): the type stays trivially
    // destructible and keeps its implicit copy/move operations.
};
// Number of keys in the input sequence; read from stdin by main().
// Note: construct1/construct2 take a parameter also called `n` that shadows this.
int n = 0;
/// Rebuilds the subtree whose preorder traversal is preorder[i, i+n) and whose
/// inorder traversal is inorder[j, j+n), where `inorder` is sorted ascending.
///
/// The root is matched against the FIRST occurrence of its key in the inorder
/// range. Keys equal to a node therefore land in its right subtree
/// (left < root <= right).
///
/// @param preorder candidate preorder sequence
/// @param inorder  the same keys sorted ascending
/// @param i        start of the subtree's range in `preorder`
/// @param j        start of the subtree's range in `inorder`
/// @param n        number of keys in the subtree
/// @return the subtree root (nodes allocated with `new`, owned by the caller).
///         Returns nullptr for an empty range, or when preorder[i] does not
///         occur in inorder[j, j+n). Callers detect an invalid sequence by
///         comparing the rebuilt tree's inorder traversal with `inorder`.
TreeNode* construct1(const vector<int>& preorder, const vector<int>& inorder, int i, int j, int n) {
    if (n <= 0) return nullptr;
    // Single-element ranges go through the same membership check as larger
    // ones, so a leaf whose key differs from inorder[j] is rejected rather
    // than silently accepted.
    const auto first = inorder.begin() + j;
    const auto last = first + n;
    const auto pos = find(first, last, preorder[i]);
    if (pos == last) return nullptr;
    const int leftSize = static_cast<int>(pos - first);  // keys strictly below the root
    const int k = j + leftSize;                          // inorder index of the root
    auto* root = new TreeNode(preorder[i]);
    root->left = construct1(preorder, inorder, i + 1, j, leftSize);
    root->right = construct1(preorder, inorder, i + 1 + leftSize, k + 1, n - leftSize - 1);
    return root;
}
/// Mirror-image counterpart of construct1. Rebuilds the subtree whose preorder
/// traversal is preorder[i, i+n) and whose inorder traversal is inorder[j, j+n),
/// where `inorder` is sorted descending.
///
/// The root is matched against the LAST occurrence of its key in the inorder
/// range. Keys equal to a node therefore land in its left subtree
/// (left >= root > right).
///
/// @param preorder candidate preorder sequence
/// @param inorder  the same keys sorted descending
/// @param i        start of the subtree's range in `preorder`
/// @param j        start of the subtree's range in `inorder`
/// @param n        number of keys in the subtree
/// @return the subtree root (nodes allocated with `new`, owned by the caller).
///         Returns nullptr for an empty range, or when preorder[i] does not
///         occur in inorder[j, j+n).
TreeNode* construct2(const vector<int>& preorder, const vector<int>& inorder, int i, int j, int n) {
    if (n <= 0) return nullptr;
    // Scan right-to-left so duplicates of the root key stay on the left.
    // Single-element ranges are validated like any other.
    int k = j + n - 1;
    while (k >= j && inorder[k] != preorder[i]) --k;
    if (k < j) return nullptr;
    const int leftSize = k - j;
    auto* root = new TreeNode(preorder[i]);
    root->left = construct2(preorder, inorder, i + 1, j, leftSize);
    root->right = construct2(preorder, inorder, i + 1 + leftSize, k + 1, n - leftSize - 1);
    return root;
}
/// Rebuilds the candidate BST from `preorder` and `inorder`, where `inorder`
/// is `preorder` sorted ascending. Delegates to construct1.
/// Returns nullptr when the two sequences differ in length. Without this guard
/// construct1 would index `inorder` out of bounds. It also returns nullptr for
/// empty input.
TreeNode* buildTree(vector<int>& preorder, vector<int>& inorder) {
    if (preorder.size() != inorder.size()) return nullptr;
    return construct1(preorder, inorder, 0, 0, static_cast<int>(preorder.size()));
}
/// Rebuilds the candidate mirror BST from `preorder` and `inorder`, where
/// `inorder` is `preorder` sorted descending. Delegates to construct2.
/// Returns nullptr when the two sequences differ in length. Without this guard
/// construct2 would index `inorder` out of bounds. It also returns nullptr for
/// empty input.
TreeNode* buildTree2(vector<int>& preorder, vector<int>& inorder) {
    if (preorder.size() != inorder.size()) return nullptr;
    return construct2(preorder, inorder, 0, 0, static_cast<int>(preorder.size()));
}
/// Appends the keys of the tree rooted at `root` to `vals` in inorder
/// (left subtree, node, right subtree). Existing contents of `vals` are kept,
/// and an empty tree appends nothing.
void inOrder(TreeNode* root, vector<int>& vals) {
    vector<TreeNode*> pending;  // ancestors whose key has not been emitted yet
    TreeNode* cursor = root;
    while (cursor != nullptr || !pending.empty()) {
        // Descend as far left as possible, remembering the path.
        for (; cursor != nullptr; cursor = cursor->left) {
            pending.push_back(cursor);
        }
        TreeNode* node = pending.back();
        pending.pop_back();
        vals.emplace_back(node->val);
        cursor = node->right;
    }
}
/// Appends the keys of the tree rooted at `root` to `ans` in postorder
/// (left subtree, right subtree, node). Existing contents of `ans` are kept,
/// and an empty tree appends nothing.
void postOrder(TreeNode* root, vector<int>& ans) {
    const auto start = ans.size();
    vector<TreeNode*> todo;
    if (root != nullptr) todo.push_back(root);
    // Collect node, right, left (left is pushed first so right pops first) ...
    while (!todo.empty()) {
        TreeNode* node = todo.back();
        todo.pop_back();
        ans.emplace_back(node->val);
        if (node->left != nullptr) todo.push_back(node->left);
        if (node->right != nullptr) todo.push_back(node->right);
    }
    // ... then reverse just the appended segment to obtain left, right, node.
    reverse(ans.begin() + start, ans.end());
}
/// Writes `ans` to `out` as space-separated integers followed by a single
/// newline. An empty vector produces just "\n".
/// @param ans values to print (not modified)
/// @param out destination stream; defaults to stdout so existing calls are unchanged
void printAns(const vector<int>& ans, FILE* out = stdout) {
    // A separator that becomes " " after the first element avoids both the
    // signed/unsigned index comparison and the `size() - 1` arithmetic.
    const char* sep = "";
    for (const int value : ans) {
        fprintf(out, "%s%d", sep, value);
        sep = " ";
    }
    fputc('\n', out);
}
int main(void) {
cin >> n;
vector<int> preorder(n);
for (int i = 0; i < n; ++i) cin >> preorder[i];
vector<int> inorder(preorder);
sort(begin(inorder), end(inorder));
auto root = buildTree(preorder, inorder);
vector<int> v;
inOrder(root, v);
if (v == inorder) {
puts("YES");
vector<int> ans;
postOrder(root, ans);
printAns(ans);
} else {
sort(rbegin(inorder), rend(inorder));
root = buildTree2(preorder, inorder);
v.clear();
inOrder(root, v);
if (v == inorder) {
puts("YES");
vector<int> ans;
postOrder(root, ans);
printAns(ans);
} else {
puts("NO");
}
}
}