@khirotaka
Last active December 30, 2019 10:41
Fix input completion for nn.MultiheadAttention (PyTorch v1.2.0) in PyCharm.
from ... import Tensor
from .. import Parameter
from .module import Module
from typing import Any, Optional
"""
add this code to
https://github.com/pytorch/pytorch/blob/ff7921e85bad0ad47bc7fa6d48c2f8762cf3f6b3/torch/nn/modules/__init__.pyi.in#L2-L6
"""
from .activation import CELU as CELU, ELU as ELU, GLU as GLU, Hardshrink as Hardshrink, Hardtanh as Hardtanh, \
    LeakyReLU as LeakyReLU, LogSigmoid as LogSigmoid, LogSoftmax as LogSoftmax, PReLU as PReLU, RReLU as RReLU, \
    ReLU as ReLU, ReLU6 as ReLU6, SELU as SELU, Sigmoid as Sigmoid, Softmax as Softmax, Softmax2d as Softmax2d, \
    Softmin as Softmin, Softplus as Softplus, Softshrink as Softshrink, Softsign as Softsign, Tanh as Tanh, \
    Tanhshrink as Tanhshrink, Threshold as Threshold, MultiheadAttention as MultiheadAttention  # <- Add
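
For reference: in an installed PyTorch 1.2.0 the edits go into the generated stubs rather than the .pyi.in templates in the source tree. A minimal, hedged sketch for locating them, assuming the installed wheel ships the generated .pyi files next to the runtime modules (paths may differ between builds and platforms):

import os
import torch

# Assumption: the installed wheel contains the generated .pyi stubs.
modules_dir = os.path.join(os.path.dirname(torch.__file__), "nn", "modules")
print(os.path.join(modules_dir, "__init__.pyi"))    # add the MultiheadAttention re-export here
print(os.path.join(modules_dir, "activation.pyi"))  # add the class stub (next file) here
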
from ... import Tensor
from .. import Parameter
from .module import Module
from typing import Any, Optional, Tuple
"""
add this code to
https://github.com/pytorch/pytorch/blob/ff7921e85bad0ad47bc7fa6d48c2f8762cf3f6b3/torch/nn/modules/activation.pyi.in#L1-L167
"""
class MultiheadAttention(Module):
    # Stub for torch.nn.MultiheadAttention (PyTorch v1.2.0) so PyCharm can complete its arguments.
    def __init__(self, embed_dim: int, num_heads: int, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None) -> None: ...
    def _reset_parameters(self) -> None: ...
    # At runtime forward returns (attn_output, attn_output_weights); the weights are None when need_weights=False.
    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask=None, need_weights=True, attn_mask=None) -> Tuple[Tensor, Optional[Tensor]]: ...
    def __call__(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask=None, need_weights=True, attn_mask=None) -> Tuple[Tensor, Optional[Tensor]]: ...
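
A quick usage sketch (not part of the stub files) to sanity-check that the stubbed arguments match the runtime API of torch.nn.MultiheadAttention in v1.2.0; inputs use the default (seq_len, batch, embed_dim) layout:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.1)
query = torch.randn(5, 2, 16)   # (target_len, batch, embed_dim)
key = torch.randn(7, 2, 16)     # (source_len, batch, embed_dim)
value = torch.randn(7, 2, 16)   # same shape as key

# forward returns a tuple: (attn_output, attn_output_weights)
attn_output, attn_weights = mha(query, key, value, need_weights=True)
print(attn_output.shape)    # torch.Size([5, 2, 16])
print(attn_weights.shape)   # torch.Size([2, 5, 7]), averaged over heads
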