Skip to content

Instantly share code, notes, and snippets.

@hehehwang
Last active October 3, 2021 19:00
Show Gist options
  • Save hehehwang/d058c6fca986a5b479afe10245f63a3e to your computer and use it in GitHub Desktop.
code2seq-jb prediction with single method
from random import shuffle
from typing import Dict, List, Optional
from code2seq.data.path_context import LabeledPathContext, Path
from code2seq.data.vocabulary import Vocabulary
from omegaconf import DictConfig
class PathContextConvert:
    """Converts one raw code2seq sample string into a ``LabeledPathContext``.

    A raw sample looks like ``"<label> <ctx> <ctx> ..."`` where each context is
    ``"from_token,path_nodes,to_token"`` and multi-part names are joined with
    ``"|"``.
    """

    _separator = "|"

    # NOTE: annotations are quoted forward references so importing this class
    # does not require omegaconf / code2seq at definition time.
    def __init__(self, vocab, config: "DictConfig", random_context: bool):
        """
        :param vocab: vocabulary exposing label_to_id / token_to_id / node_to_id
        :param config: data config (max_context, max_label_parts,
            max_token_parts, path_length)
        :param random_context: if True, shuffle contexts before truncating
            to ``max_context``
        """
        self._config = config
        self._vocab = vocab
        self._random_context = random_context

    def getPathContext(self, sourceCode: str) -> "LabeledPathContext":
        """Parse one raw sample string into a labeled list of tokenized paths."""
        raw_label, *raw_path_contexts = sourceCode.split()

        # Optionally shuffle, then keep at most max_context contexts.
        n_contexts = min(len(raw_path_contexts), self._config.max_context)
        if self._random_context:
            shuffle(raw_path_contexts)
        raw_path_contexts = raw_path_contexts[:n_contexts]

        # Single-part labels use a strict lookup; multi-part labels get
        # SOS/EOS/PAD framing.
        if self._config.max_label_parts == 1:
            label = self.tokenize_class(raw_label, self._vocab.label_to_id)
        else:
            label = self.tokenize_label(raw_label, self._vocab.label_to_id, self._config.max_label_parts)

        paths = [self._get_path(raw_path.split(",")) for raw_path in raw_path_contexts]
        return LabeledPathContext(label, paths)

    @staticmethod
    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]:
        """Single-token label lookup; raises KeyError for an unknown class."""
        return [vocab[raw_class]]

    @staticmethod
    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
        """Tokenize a ``|``-separated label to a fixed length of max_parts + 1.

        Layout: [SOS, part ids..., EOS (only if all parts fit), PAD...].
        Unknown sub-labels map to UNK.
        """
        sublabels = raw_label.split(PathContextConvert._separator)
        max_parts = max_parts or len(sublabels)
        label_unk = vocab[Vocabulary.UNK]
        label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
        if len(sublabels) < max_parts:
            label.append(vocab[Vocabulary.EOS])
        label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label))
        return label

    @staticmethod
    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
        """Tokenize a ``|``-separated token into exactly max_parts ids.

        Unknown sub-tokens map to UNK; the tail is filled with PAD.
        """
        sub_tokens = token.split(PathContextConvert._separator)
        max_parts = max_parts or len(sub_tokens)
        token_unk = vocab[Vocabulary.UNK]
        result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
        result += [vocab[Vocabulary.PAD]] * (max_parts - len(result))
        return result

    def _get_path(self, raw_path: List[str]) -> "Path":
        """Tokenize one ``[from, path, to]`` triple into a Path."""
        return Path(
            from_token=self.tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts),
            path_node=self.tokenize_token(raw_path[1], self._vocab.node_to_id, self._config.path_length),
            to_token=self.tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts),
        )
from typing import List, cast
import torch
from code2seq.data.vocabulary import Vocabulary
from code2seq.model.code2seq import Code2Seq
from omegaconf import OmegaConf
from astParser import pyASTParser
from getLabeledPathContext import PathContextConvert
# Load the training config and the serialized vocabulary used by the model.
config = OmegaConf.load('0920.yaml')
VOCAB = Vocabulary(
    'vocabulary.pkl',
    max_labels=config.data.max_labels,  # attribute access for consistency with max_tokens
    max_tokens=config.data.max_tokens,
)

# Reverse map for decoding predicted label ids back to label strings.
id_to_label = {idx: lab for (lab, idx) in VOCAB.label_to_id.items()}

# random_context=True: contexts are shuffled before truncation.
converter = PathContextConvert(VOCAB, config.data, True)
parser = pyASTParser()
# Sample method to push through the pipeline.  Indentation inside the string
# matters: the snippet is parsed as real Python source by pyASTParser.
testCode = """
def getFactorial(n):
    if n == 0:
        return 1
    else:
        return n * getFactorial(n-1)
"""
parser.readSourceCode(testCode)
parsed = parser.getParsedContextPaths()
# Tokenized label + path contexts for the single method above.
s = converter.getPathContext(parsed)
def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    """Transpose a rectangular matrix: list of rows -> list of columns.

    Fix: each ``zip`` tuple is converted to a real list so the return value
    actually matches the annotation (the original's ``typing.cast`` was a
    no-op and returned tuples).
    """
    return [list(column) for column in zip(*list_of_lists)]
# Batch the single sample: code2seq expects [max_parts; n_contexts] tensors,
# hence the transpose of the per-path token lists.
from_token = torch.tensor(transpose([path.from_token for path in s.path_contexts]), dtype=torch.long)
path_nodes = torch.tensor(transpose([path.path_node for path in s.path_contexts]), dtype=torch.long)
to_token = torch.tensor(transpose([path.to_token for path in s.path_contexts]), dtype=torch.long)
# One label in the batch -> all contexts belong to it.
contexts = torch.tensor([len(s.path_contexts)])

# NOTE(review): 'chekcpoint.ckpt' looks like a typo for 'checkpoint.ckpt';
# kept as-is in case the file on disk really is named this way — verify.
c2s = Code2Seq.load_from_checkpoint('chekcpoint.ckpt')

output = c2s(
    from_token=from_token,
    path_nodes=path_nodes,
    to_token=to_token,
    contexts_per_label=contexts,
    output_length=7,
)
print(output)

# Greedy decode: drop the batch dim, take the argmax id per output step.
predictions = output.squeeze(1).argmax(-1)
# BUG FIX: id_to_label is a dict — index it; calling it raised TypeError.
labels = [id_to_label[i.item()] for i in predictions]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment