code2seq-jb prediction with single method

Two files: getLabeledPathContext.py converts a raw parsed sample into code2seq's LabeledPathContext, and a driver script that loads a trained Code2Seq checkpoint and predicts a name for a single Python method.
# getLabeledPathContext.py: convert a raw "label context context ..." sample
# into code2seq's LabeledPathContext.

from random import shuffle
from typing import Dict, List, Optional

from code2seq.data.path_context import LabeledPathContext, Path
from code2seq.data.vocabulary import Vocabulary
from omegaconf import DictConfig


class PathContextConvert:
    _separator = "|"

    def __init__(self, vocab, config: DictConfig, random_context: bool):
        self._config = config
        self._vocab = vocab
        self._random_context = random_context

    def getPathContext(self, sourceCode: str) -> LabeledPathContext:
        # A raw sample is the label followed by whitespace-separated path contexts.
        raw_label, *raw_path_contexts = sourceCode.split()

        # Keep at most max_context contexts, optionally sampling them at random.
        n_contexts = min(len(raw_path_contexts), self._config.max_context)
        if self._random_context:
            shuffle(raw_path_contexts)
        raw_path_contexts = raw_path_contexts[:n_contexts]

        if self._config.max_label_parts == 1:
            label = self.tokenize_class(raw_label, self._vocab.label_to_id)
        else:
            label = self.tokenize_label(raw_label, self._vocab.label_to_id, self._config.max_label_parts)

        paths = [self._get_path(raw_path.split(",")) for raw_path in raw_path_contexts]
        return LabeledPathContext(label, paths)

    @staticmethod
    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]:
        return [vocab[raw_class]]

    @staticmethod
    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
        # "get|factorial" -> [SOS, get, factorial, (EOS), PAD, ...] of length max_parts + 1.
        sublabels = raw_label.split(PathContextConvert._separator)
        max_parts = max_parts or len(sublabels)
        label_unk = vocab[Vocabulary.UNK]

        label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
        if len(sublabels) < max_parts:
            label.append(vocab[Vocabulary.EOS])
        label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label))
        return label

    @staticmethod
    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
        sub_tokens = token.split(PathContextConvert._separator)
        max_parts = max_parts or len(sub_tokens)
        token_unk = vocab[Vocabulary.UNK]

        result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
        result += [vocab[Vocabulary.PAD]] * (max_parts - len(result))
        return result

    def _get_path(self, raw_path: List[str]) -> Path:
        # A raw path context is "from_token,path_nodes,to_token".
        return Path(
            from_token=self.tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
            path_node=self.tokenize_token(raw_path[1], self._vocab.node_to_id, self._config.path_length),
            to_token=self.tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
        )
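A quick sanity check of the label tokenization, as a minimal sketch: the vocabulary below is a hypothetical toy mapping, not the contents of vocabulary.pkl, and it relies only on the special-token constants that code2seq's Vocabulary exposes.

# Hypothetical toy vocabulary, only for illustration.
toy_vocab = {Vocabulary.PAD: 0, Vocabulary.UNK: 1, Vocabulary.SOS: 2,
             Vocabulary.EOS: 3, "get": 4, "factorial": 5}
# "get|factorial" with max_parts=4 -> [SOS, get, factorial, EOS, PAD]
print(PathContextConvert.tokenize_label("get|factorial", toy_vocab, 4))
# -> [2, 4, 5, 3, 0]

The driver script below uses the same converter end to end for one method.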
from typing import List, cast

import torch
from code2seq.data.vocabulary import Vocabulary
from code2seq.model.code2seq import Code2Seq
from omegaconf import OmegaConf

from astParser import pyASTParser
from getLabeledPathContext import PathContextConvert

# Load the training config and the vocabulary used by the checkpoint.
config = OmegaConf.load('0920.yaml')
VOCAB = Vocabulary('vocabulary.pkl',
                   max_labels=config.data.max_labels,
                   max_tokens=config.data.max_tokens)
id_to_label = {idx: lab for (lab, idx) in VOCAB.label_to_id.items()}

converter = PathContextConvert(VOCAB, config.data, True)
parser = pyASTParser()

testCode = """
def getFactorial(n):
    if n == 0:
        return 1
    else:
        return n * getFactorial(n-1)
"""

# Parse the method into path contexts and convert them to id lists.
parser.readSourceCode(testCode)
parsed = parser.getParsedContextPaths()
s = converter.getPathContext(parsed)


def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    return [cast(List[int], it) for it in zip(*list_of_lists)]


# Shape everything as [parts; total contexts], the layout Code2Seq expects.
from_token = torch.tensor(transpose([path.from_token for path in s.path_contexts]), dtype=torch.long)
path_nodes = torch.tensor(transpose([path.path_node for path in s.path_contexts]), dtype=torch.long)
to_token = torch.tensor(transpose([path.to_token for path in s.path_contexts]), dtype=torch.long)
contexts = torch.tensor([len(s.path_contexts)])

c2s = Code2Seq.load_from_checkpoint('checkpoint.ckpt')
output = c2s(from_token=from_token,
             path_nodes=path_nodes,
             to_token=to_token,
             contexts_per_label=contexts,
             output_length=7)
print(output)

# output is [output length; batch; vocab size]; drop the batch dimension
# and take the most likely sublabel at each decoding step.
predictions = output.squeeze(1).argmax(-1)
labels = [id_to_label[i.item()] for i in predictions]
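The labels list still contains decoder bookkeeping tokens. A minimal post-processing sketch, assuming the special-token names from code2seq's Vocabulary, joins the predicted sublabels into one name and stops at the first EOS:

# Drop SOS/PAD/UNK and stop decoding at EOS.
skip = {Vocabulary.SOS, Vocabulary.PAD, Vocabulary.UNK}
name_parts = []
for lab in labels:
    if lab == Vocabulary.EOS:
        break
    if lab not in skip:
        name_parts.append(lab)
print("predicted name:", "|".join(name_parts))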