// Approach #1: DFS + memoization — O(n^3) time, O(n^2) space.
// Approach #2: DP — O(n^3) time, O(n^2) space.
// With the maxLength optimization the time can be reduced to O(n * maxLength * maxLength).

// Approach #1: DFS + memoization — O(n * maxLength^2) time, O(n) space.
public class Solution {
    /**
     * Counts the ways {@code s} can be split into a sequence of dictionary words.
     * Matching is case-insensitive: both the dictionary and the string are lower-cased first.
     *
     * @param s the string to segment
     * @param dict the set of allowed words
     * @return the number of possible sentences; 0 when {@code s} or {@code dict} is null or empty
     */
    public int wordBreak3(String s, Set<String> dict) {
        // handle corner cases
        if (dict == null || dict.isEmpty() || s == null || s.isEmpty()) {
            return 0;
        }

        // Lower-case with a fixed locale so results do not depend on the JVM's default locale.
        Set<String> words = new HashSet<>();
        int maxLength = 0;
        for (String word : dict) {
            String lowered = word.toLowerCase(java.util.Locale.ROOT);
            words.add(lowered);
            // Measure the lowered form: that is the text actually compared against.
            maxLength = Math.max(maxLength, lowered.length());
        }

        // Size the memo from the lowered text, since lower-casing may change the length.
        String text = s.toLowerCase(java.util.Locale.ROOT);

        // The right boundary of every sub-problem is the end of the text, so one
        // memo slot per start index is enough; -1 marks "not computed yet".
        int[] memo = new int[text.length()];
        Arrays.fill(memo, -1);

        return dfs(words, text, 0, maxLength, memo);
    }

    /**
     * Returns the number of ways to segment the suffix of {@code text} beginning at {@code start}.
     *
     * @param words lower-cased dictionary
     * @param text lower-cased string being segmented
     * @param start index of the first character of the suffix
     * @param maxLength length of the longest dictionary word; longer candidates are never tried
     * @param memo per-start cache of already computed answers, -1 when unknown
     * @return the count of segmentations of {@code text.substring(start)}
     */
    private static int dfs(Set<String> words, String text, int start, int maxLength, int[] memo) {
        if (start == text.length()) {
            // The whole string was consumed: exactly one way to segment nothing.
            return 1;
        }
        if (memo[start] != -1) {
            return memo[start];
        }

        int count = 0;
        int limit = Math.min(text.length(), start + maxLength);
        for (int end = start + 1; end <= limit; end++) {
            if (words.contains(text.substring(start, end))) {
                count += dfs(words, text, end, maxLength, memo);
            }
        }
        memo[start] = count;

        return count;
    }
}

// Approach #2: bottom-up DP over prefixes — O(n * maxLength^2) time, O(n) space.
public class Solution {
    /**
     * Counts the ways {@code s} can be split into a sequence of dictionary words.
     * Matching is case-insensitive: both the dictionary and the string are lower-cased first.
     *
     * @param s the string to segment
     * @param dict the set of allowed words
     * @return the number of possible sentences; 0 when {@code s} or {@code dict} is null or empty
     */
    public int wordBreak3(String s, Set<String> dict) {
        // handle corner cases
        if (dict == null || dict.isEmpty() || s == null || s.isEmpty()) {
            return 0;
        }

        // Lower-case with a fixed locale so results do not depend on the JVM's default locale.
        Set<String> words = new HashSet<>();
        int maxLength = 0;
        for (String word : dict) {
            String lowered = word.toLowerCase(java.util.Locale.ROOT);
            words.add(lowered);
            // Measure the lowered form: that is the text actually compared against.
            maxLength = Math.max(maxLength, lowered.length());
        }

        String text = s.toLowerCase(java.util.Locale.ROOT);
        int length = text.length();

        // ways[end] = number of segmentations of the prefix text[0, end).
        // The empty prefix has exactly one (empty) segmentation.
        int[] ways = new int[length + 1];
        ways[0] = 1;

        for (int end = 1; end <= length; end++) {
            // The last word of a segmentation cannot be longer than the longest dictionary word.
            int earliest = Math.max(0, end - maxLength);
            for (int start = earliest; start < end; start++) {
                // Skip unreachable prefixes before paying for the substring and hash lookup.
                if (ways[start] != 0 && words.contains(text.substring(start, end))) {
                    ways[end] += ways[start];
                }
            }
        }

        return ways[length];
    }
}
}