Skip to content

Instantly share code, notes, and snippets.

@qedawkins
Created September 13, 2022 05:00
Show Gist options
  • Save qedawkins/5533c90d54988e4d13ac17528f49bffe to your computer and use it in GitHub Desktop.
Save qedawkins/5533c90d54988e4d13ac17528f49bffe to your computer and use it in GitHub Desktop.
import torch
from torch import nn
from torchrec.datasets.utils import Batch
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from typing import Dict, List, Optional, Tuple
from torchrec.models.dlrm import (
choose,
DenseArch,
DLRM,
InteractionArch,
SparseArch,
OverArch,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import numpy as np
# Pin both global RNGs so runs are reproducible: NumPy's generator supplies
# the embedding weights drawn in ToyEmbeddingBag, and torch's generator is
# consumed by default layer initialisation.
np.random.seed(0)
torch.manual_seed(0)
class ToyEmbeddingBag(nn.Module):
    """Single ``nn.EmbeddingBag`` in ``mode="sum"`` with NumPy-drawn weights.

    The weight table is sampled from
    ``U(-sqrt(1 / num_embeddings), sqrt(1 / num_embeddings))`` using NumPy's
    global RNG. Seeding ``np.random`` therefore makes the weights
    reproducible.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
        bound = np.sqrt(1 / num_embeddings)
        W = np.random.uniform(
            low=-bound,
            high=bound,
            size=(num_embeddings, embedding_dim),
        ).astype(np.float32)
        # Overwrite the default init in place instead of rebinding
        # ``weight.data``. This keeps the same Parameter object (still
        # requires_grad=True). It also avoids building a throwaway tensor whose
        # ``requires_grad=True`` flag would be silently ignored.
        with torch.no_grad():
            self.embedding.weight.copy_(torch.from_numpy(W))

    def forward(self, vals, offsets):
        """Return the per-bag sum of embeddings.

        Args:
            vals: 1-D tensor of row indices, with all bags concatenated.
            offsets: 1-D tensor giving the start position of each bag in
                ``vals``.

        Returns:
            Tensor of shape ``(len(offsets), embedding_dim)``.
        """
        return self.embedding(vals, offsets)
def test_embedding(
    *, device: str = "intel-gpu", debug_dir: str = "/home/quinn/tmp"
) -> None:
    """Compile ``ToyEmbeddingBag`` with SHARK and compare against eager PyTorch.

    The module is imported via ``SharkImporter`` (torch frontend, tracing
    enabled). It is then compiled with ``SharkInference`` for ``device``
    using the linalg dialect. Its output on a fixed two-bag input is checked
    against the eager PyTorch result.

    Args:
        device: SHARK target device. Defaults to the original ``"intel-gpu"``.
        debug_dir: Directory handed to ``SharkImporter.import_debug``. The
            default is the original machine-specific path; override it when
            running elsewhere.

    Raises:
        AssertionError: If the compiled output differs from the eager output
            beyond ``rtol=1e-2`` / ``atol=1e-3``.
    """
    toy = ToyEmbeddingBag(10, 3)
    # Offsets [0, 1] split the values into two bags: [1] and [2, 4, 5].
    values = torch.tensor([1, 2, 4, 5], dtype=torch.int64)
    offsets = torch.tensor([0, 1], dtype=torch.int64)
    inputs = (values, offsets)

    mlir_importer = SharkImporter(
        toy,
        inputs,
        frontend="torch",
    )
    # import_debug also returns its own inputs and golden output. They are
    # discarded because the reference is recomputed eagerly below.
    (mlir_module, func_name), _, _ = mlir_importer.import_debug(
        tracing_required=True, dir=debug_dir
    )

    shark_module = SharkInference(
        mlir_module, func_name, device=device, mlir_dialect="linalg"
    )
    shark_module.compile()
    result = shark_module.forward(inputs)

    golden_out = toy(values, offsets).detach()
    # Pass the eager result as ``desired`` so rtol is measured against the
    # reference rather than against the compiled output under test.
    np.testing.assert_allclose(result, golden_out.numpy(), rtol=1e-02, atol=1e-03)
# Run the SHARK-vs-eager comparison at module level. This executes whenever
# the file is run *or imported*.
# NOTE(review): consider guarding this with `if __name__ == "__main__":`.
# Otherwise importing the module (e.g. during pytest collection, which also
# collects `test_embedding` itself) triggers an extra compile/device run.
test_embedding()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment