class Solution {
public:
    int minTotalDistance(vector<vector<int>>& grid) {
        if(grid.size()==0)
            return 0;
        int rows = grid.size(); 
        int cols = grid[0].size();
        
        vector<int> X(cols);
        vector<int> Y(rows);
        
        for(int i=0; i<rows; i++){
            for(int j=0; j<cols; j++){
                if(grid[i][j]==0)
                    continue;
                int x = j;
                int y=  i;
                
                for(int k=0; k<cols; k++)
                    X[k]=X[k]+abs(x--);
                for(int k=0; k<rows; k++)
                    Y[k]=Y[k]+abs(y--);
            }
        }
        
        int minv = INT_MAX;
        for(int i=0; i<rows; i++)
            for(int j=0; j<cols; j++)
                minv = min(minv, X[j]+Y[i]);
        return minv;
    }
};