Last active: December 30, 2019 10:41
-
-
Save khirotaka/a4953d6e99223220dc92d71c2d166166 to your computer and use it in GitHub Desktop.
Fixes `nn.MultiheadAttention` input completion in PyCharm for PyTorch v1.2.0.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Optional

from ... import Tensor
from .. import Parameter
from .module import Module

# Add the MultiheadAttention re-export below to
# https://github.com/pytorch/pytorch/blob/ff7921e85bad0ad47bc7fa6d48c2f8762cf3f6b3/torch/nn/modules/__init__.pyi.in#L2-L6
# so that `torch.nn.MultiheadAttention` is visible to PyCharm's input completion.
# The redundant-looking `Name as Name` form is deliberate: in stub files only
# names imported this way are treated as explicit re-exports (PEP 484).
from .activation import (
    CELU as CELU,
    ELU as ELU,
    GLU as GLU,
    Hardshrink as Hardshrink,
    Hardtanh as Hardtanh,
    LeakyReLU as LeakyReLU,
    LogSigmoid as LogSigmoid,
    LogSoftmax as LogSoftmax,
    PReLU as PReLU,
    RReLU as RReLU,
    ReLU as ReLU,
    ReLU6 as ReLU6,
    SELU as SELU,
    Sigmoid as Sigmoid,
    Softmax as Softmax,
    Softmax2d as Softmax2d,
    Softmin as Softmin,
    Softplus as Softplus,
    Softshrink as Softshrink,
    Softsign as Softsign,
    Tanh as Tanh,
    Tanhshrink as Tanhshrink,
    Threshold as Threshold,
    MultiheadAttention as MultiheadAttention,  # <- Add
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Optional, Tuple

from ... import Tensor
from .. import Parameter
from .module import Module
""" | |
add this code to | |
https://github.com/pytorch/pytorch/blob/ff7921e85bad0ad47bc7fa6d48c2f8762cf3f6b3/torch/nn/modules/activation.pyi.in#L1-L167 | |
""" | |
class MultiheadAttention(Module):
    # Type stub for torch.nn.MultiheadAttention (PyTorch 1.2.0), consumed by the
    # IDE for input completion; method bodies are intentionally `...`.

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0., bias: bool = True,
                 add_bias_kv: bool = False, add_zero_attn: bool = False,
                 kdim: Optional[int] = None, vdim: Optional[int] = None) -> None: ...

    def _reset_parameters(self) -> None: ...

    # forward returns a pair (attn_output, attn_output_weights) rather than a
    # single Tensor; the weights element is None when need_weights=False.
    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                key_padding_mask: Optional[Tensor] = None, need_weights: bool = True,
                attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: ...

    # Declared so that calling the module instance (`mha(q, k, v)`) gets the same
    # signature and return type as forward in completion/type checking.
    def __call__(self, query: Tensor, key: Tensor, value: Tensor,
                 key_padding_mask: Optional[Tensor] = None, need_weights: bool = True,
                 attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.