@rdednl
Last active September 7, 2023 15:01
batch norm is bad (td3/sac)
@honglu2875

honglu2875 commented Aug 17, 2022

Well, when you updated the target net, you used parameters(). I can see the code is from stable-baselines, which is not designed for batch norm. Batch norm in fact has two more variables that are not included in parameters(). What you had was in fact a partially updated target net, which behaves dramatically differently at inference time.
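
For readers landing here later, a minimal sketch of what this means in PyTorch (names like `soft_update` and the toy actor are mine, not from the gist): a soft target update done through parameters() alone skips the batch-norm buffers, so they have to be synced separately.

```python
import torch
import torch.nn as nn

def soft_update(source: nn.Module, target: nn.Module, tau: float = 0.005):
    """Polyak-average the parameters and copy over the batch-norm buffers."""
    with torch.no_grad():
        # Gradient-trained parameters (Linear weights/biases, BatchNorm affine params).
        for p_src, p_tgt in zip(source.parameters(), target.parameters()):
            p_tgt.mul_(1.0 - tau).add_(tau * p_src)
        # Buffers such as running_mean, running_var, num_batches_tracked are NOT
        # returned by parameters(); without this loop the target net keeps stale
        # normalization statistics.
        for b_src, b_tgt in zip(source.buffers(), target.buffers()):
            b_tgt.copy_(b_src)

# Toy actor/target pair just to exercise the function.
def make_actor():
    return nn.Sequential(nn.Linear(8, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 2))

actor, actor_target = make_actor(), make_actor()
actor_target.load_state_dict(actor.state_dict())
soft_update(actor, actor_target)
```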

@honglu2875

Also, you abused .eval() and .train(). I suggest you learn how batch norm works and what the implications of those methods are. I didn't look into the exact logic of each algorithm, but I also have a vague feeling that something is not exactly as in the paper (though it might still work).

ps: I was just able to fix your TD3 and made the batch-norm model run as well as the one without. I don't have time to do SAC for you, but you should most likely be able to fix it yourself too if you understand batch norm correctly.
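
On the .eval()/.train() point, the relevant behaviour is that a BatchNorm layer normalizes with per-batch statistics (and updates its running estimates) in train() mode, but uses the stored running estimates in eval() mode. A hypothetical pattern for a TD3-style update, just to illustrate the modes (none of this code is from the gist):

```python
import torch
import torch.nn as nn

actor = nn.Sequential(nn.Linear(8, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 2))
actor_target = nn.Sequential(nn.Linear(8, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 2))
actor_target.load_state_dict(actor.state_dict())

obs = torch.randn(32, 8)       # batch of observations (toy shapes)
next_obs = torch.randn(32, 8)

# Target networks are inference-only: keep them in eval() so batch norm uses
# its running statistics and never updates them.
actor_target.eval()
with torch.no_grad():
    next_action = actor_target(next_obs)

# The online actor is in train() mode only while it is actually being optimized.
actor.train()
action = actor(obs)            # normalizes with batch statistics, updates running stats
```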

@rdednl
Author

rdednl commented Sep 29, 2022

@honglu2875 Hi. My code is not from stable baselines. Also, the batch norm learnable parameters that have to be updated on the target net are present in parameters():

```python
for model_param in model.model.actor.layers[1].parameters():
    print(model_param.shape)
```

which prints

```
torch.Size([64])
torch.Size([64])
```

What are the variables that are missing?

Also, what do you mean when you say I abused .eval() and .train()?

@honglu2875

honglu2875 commented Sep 29, 2022

Check out the properties whose names start with "running_" (either in your batch norm layer or in its state_dict). They are "learnable" in the sense that they change during training, but not by gradients, and they are not present in parameters().

All learnable parameters are in state_dict(); parameters() contains only those that are updated by gradients.
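
A quick way to see this on a standalone BatchNorm1d layer (just an illustration, not the gist's actor):

```python
import torch.nn as nn

bn = nn.BatchNorm1d(64)

# Updated by gradients -> visible through parameters()
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']

# Updated during forward passes in train() mode, never by gradients
# -> only in buffers() / state_dict()
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']

print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```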

@honglu2875

honglu2875 commented Sep 29, 2022

> My code is not from stable baselines.

Ahh... so this misunderstanding has spread wider than I thought. Maybe there is a chain of misuse and people never bother to check.
When stable-baselines came out there was no such thing as batch norm in it, by the way. The code is great and should indeed be our baseline implementation. But we, "the later generations", really have more responsibility when working on top of earlier code.
