/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    // Left node of the first out-of-order pair seen during the in-order walk.
    TreeNode first = null;
    // Right node of the second out-of-order pair, if a second pair exists.
    TreeNode second = null;
    // Right node of the first out-of-order pair. It is used as the swap partner
    // when the two swapped nodes are adjacent in in-order sequence, which
    // produces only one inversion.
    TreeNode afterFirst = null;

    /**
     * Restores a binary search tree in which exactly two nodes' values were
     * swapped. The fix swaps those two values back in place.
     *
     * An empty tree, or a tree with no out-of-order pair, is left unchanged.
     * State from any previous call is cleared first, so the same instance can
     * be reused.
     *
     * @param root root of the tree to repair; may be null
     */
    public void recoverTree(TreeNode root) {
        // Clear results left over from an earlier invocation on this instance.
        first = null;
        second = null;
        afterFirst = null;

        inOrder(root, null);

        if (first == null) {
            // No inversion found: the tree is empty or already a valid BST.
            return;
        }
        if (second == null) {
            // Only one inversion, so the swapped nodes are in-order neighbours.
            second = afterFirst;
        }
        int temp = first.val;
        first.val = second.val;
        second.val = temp;
    }

    /**
     * Walks the subtree rooted at {@code root} in order and records every
     * adjacent pair whose values are descending.
     *
     * The walk is iterative and uses an explicit stack, so very deep
     * (degenerate) trees do not exhaust the call stack.
     *
     * @param root subtree to walk; may be null
     * @param prev node visited immediately before this subtree, or null
     * @return the last node visited, or {@code prev} when {@code root} is null
     */
    public TreeNode inOrder(TreeNode root, TreeNode prev) {
        java.util.Deque<TreeNode> stack = new java.util.ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the path.
            while (node != null) {
                stack.push(node);
                node = node.left;
            }
            node = stack.pop();
            recordInversion(prev, node);
            prev = node;
            node = node.right;
        }
        return prev;
    }

    // Records (prev, curr) if it is out of order: the first pair fills
    // first/afterFirst, and any later pair overwrites second.
    private void recordInversion(TreeNode prev, TreeNode curr) {
        if (prev != null && prev.val > curr.val) {
            if (first == null) {
                first = prev;
                afterFirst = curr;
            } else {
                second = curr;
            }
        }
    }
}