// Repairs a binary search tree in which two nodes' values were swapped by
// mistake, restoring sorted in-order order without changing the tree's shape.
// Assumes exactly two values are out of place; that is the case this
// detection scheme is designed for.
//
// An in-order walk of a valid BST yields non-decreasing values. Swapping two
// values produces one of two patterns:
//   - one inversion, when the swapped nodes are adjacent in in-order;
//   - two inversions, when they are not.
// The culprits are the left node of the first inversion and the right node of
// the last inversion. Their values are swapped back.
//
// Traversal is Morris (threaded) in-order: O(n) time, O(1) extra space and no
// recursion, so a list-shaped tree cannot exhaust the call stack. It
// temporarily rewrites some `right` pointers and removes every such thread
// before the loop finishes, so the tree must not be read concurrently.
void recoverTree(TreeNode *root) {
    TreeNode *first = nullptr;   // left node of the first inversion
    TreeNode *second = nullptr;  // right node of the most recent inversion
    TreeNode *prev = nullptr;    // previously visited node, in in-order

    // Handles one node in in-order sequence.
    auto visit = [&](TreeNode *node) {
        if (prev != nullptr && node->val < prev->val) {
            // Overwrite 'second' on every inversion so that a single inversion
            // (adjacent swap) and two inversions (non-adjacent swap) both work.
            second = node;
            if (first == nullptr)
                first = prev;
        }
        prev = node;
    };

    TreeNode *cur = root;
    while (cur != nullptr) {
        if (cur->left == nullptr) {
            visit(cur);
            cur = cur->right;  // may follow a thread back up to an ancestor
            continue;
        }

        // Find cur's in-order predecessor: the rightmost node of its left
        // subtree, stopping early if that node already threads back to cur.
        TreeNode *pred = cur->left;
        while (pred->right != nullptr && pred->right != cur)
            pred = pred->right;

        if (pred->right == nullptr) {
            // First arrival: thread the predecessor back to cur, then go left.
            pred->right = cur;
            cur = cur->left;
        } else {
            // Returned via the thread: the left subtree is done. Remove the
            // thread, visit cur, then continue with the right subtree.
            pred->right = nullptr;
            visit(cur);
            cur = cur->right;
        }
    }

    if (first && second)
        std::swap(first->val, second->val);
}

// Recursive in-order (left, node, right) walk that locates the two nodes whose
// values break the sorted in-order sequence of a BST.
//
//   root    - subtree to walk; nullptr is a no-op.
//   preNode - in: the node visited just before this subtree (nullptr at the
//             start). Out: the last node visited in this subtree. It must be
//             updated after every visit so that comparisons chain across
//             subtrees.
//   first   - set once: the left node of the first inversion found.
//   second  - overwritten at every inversion, so it ends as the right node of
//             the last inversion.
//
// Updating 'second' unconditionally, rather than only on a second inversion,
// is what makes a swap of two in-order-adjacent nodes work. That case yields
// a single inversion, and both pointers must come from it.
void recoverTreeRe(TreeNode *root, TreeNode *&preNode, TreeNode* &first, TreeNode* &second){
    if (root != nullptr) {
        recoverTreeRe(root->left, preNode, first, second);

        const bool outOfOrder = preNode != nullptr && root->val < preNode->val;
        if (outOfOrder) {
            if (first == nullptr)
                first = preNode;
            second = root;
        }
        preNode = root;

        recoverTreeRe(root->right, preNode, first, second);
    }
}