每个self.optimizer.param_group里所有的param flatten后, 切分到各个dp rank上 self.single_partition_of_fp32_groups即属于该dp rank的params initialize_optimizer_states: per single_partition_of_fp32_groups创建对应的optimizer master weight和grad
self.round_robin_gradients 在切分前shuffle params, 这样torch module相邻的param可以属于不同的dp rank, 从而gradient allreduce后每个dp rank都能有属于自己负责的gradient