Skip to content

Instantly share code, notes, and snippets.

@catdingding
Last active September 26, 2023 08:32
Show Gist options
  • Save catdingding/e1f539add03be46cb428ed19d736e782 to your computer and use it in GitHub Desktop.
Save catdingding/e1f539add03be46cb428ed19d736e782 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
from collections.abc import Iterable
from tqdm.auto import tqdm
from multiprocessing import Pool

from vina2vi.models.char_based.bigram import Bigram
from vina2vi.util import (
    Vietnamese,
    uncased_vina_normalizer,
    cased_vi_normalizer,
)


def count_np(s: str):
    """Return the character-bigram count matrix of a single string.

    The result is a square ``np.int32`` array with one row and one column
    per entry of ``Bigram.itoc``. Entry ``[i, j]`` is the number of times
    the character with index ``j`` directly follows the character with
    index ``i`` in the normalized, lower-cased text, where the text is
    framed by ``Bigram.bos_token`` at the start and ``Bigram.eos_token``
    at the end. Characters missing from ``Bigram.ctoi`` are counted under
    ``Bigram.unk_token``.

    An empty input string yields an all-zero matrix.

    This is a module-level function (rather than a method) so that it can
    be handed to a ``multiprocessing`` pool.
    """
    vocab_size = len(Bigram.itoc)
    counts = np.zeros((vocab_size, vocab_size), dtype=np.int32)
    if s == "":
        return counts

    # Without normalization, one may obtain a very different count matrix
    text = cased_vi_normalizer.normalize_str(s).lower()

    char_to_index = Bigram.ctoi
    unk = char_to_index[Bigram.unk_token]
    indices = [char_to_index.get(char, unk) for char in text]

    # Row indices are the "previous" characters, column indices the "next"
    # ones; the two lists are the same sequence shifted by one position.
    previous = [char_to_index.get(Bigram.bos_token, unk), *indices]
    following = [*indices, char_to_index.get(Bigram.eos_token, unk)]

    # np.add.at accumulates repeated (row, col) pairs, unlike fancy-index +=.
    np.add.at(counts, (previous, following), 1)
    return counts

class BigramNew(Bigram):
    """``Bigram`` subclass whose ``fit`` counts bigrams in worker processes."""

    def fit(
        self,
        data: Iterable[str],
        *,
        total: int | None = None,
        chunksize: int = 1,
    ) -> None:
        """Accumulate bigram counts from ``data`` using a process pool.

        Each string is turned into a count matrix by ``count_np`` in a
        worker process; the matrices are added into ``self.count_matrix``
        in the parent as they arrive, and ``self.update_proba_matrix()`` is
        called once at the end.

        Args:
            data: The strings to count bigrams in.
            total: Number of strings, used only to size the progress bar.
                When omitted, ``len(data)`` is used if ``data`` is sized.
            chunksize: Passed through to ``Pool.imap_unordered``; how many
                strings are shipped to a worker per task.
        """
        if total is None:
            # The pool's result iterator has no length, so recover the size
            # from ``data`` itself when it has one.
            try:
                total = len(data)
            except (TypeError, AttributeError):
                total = None

        # Multiprocessing pool idea borrowed from mCoding
        # https://www.youtube.com/watch?v=X7vBbelRXn0&t=280s
        with Pool() as pool:
            # Unable to use a method like self.count in imap_unordered() here
            # because the class Bigram contains a torch.Generator, which is not picklable.
            matrices = pool.imap_unordered(
                count_np,
                data,
                chunksize=chunksize,
            )
            # Track progress on the results, not on the inputs: the pool
            # consumes the input iterable from a feeder thread ahead of the
            # workers, so a bar around ``data`` would show queuing speed
            # instead of completed work.
            for matrix in tqdm(matrices, total=total):
                self.count_matrix += torch.from_numpy(matrix)

        self.update_proba_matrix()

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment