This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import functools | |
| import os | |
| import shutil | |
| import subprocess | |
| from contextlib import contextmanager | |
| from typing import Any, Callable, Dict, Iterator, Tuple | |
| from dotenv import load_dotenv | |
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | from typing import Tuple | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from numpy.typing import NDArray | |
| import pandas as pd | |
| from sklearn.metrics import ( | |
| ConfusionMatrixDisplay, | |
| accuracy_score, | |
| auc, | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Batch matrix multiplication: for each batch index i, A[i] = X[i] @ Y[i].
# X has shape (2, 3, 4) and Y has shape (2, 4, 5), so A has shape (2, 3, 5).
X = torch.arange(24).reshape(2, 3, 4)
Y = torch.arange(40).reshape(2, 4, 5)
A = torch.einsum("ijk,ikl->ijl", X, Y)
# torch.bmm is the dedicated equivalent; verify the two agree rather than
# computing it and throwing the result away.
torch.testing.assert_close(A, torch.bmm(X, Y))
print(A)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Matrix-matrix multiplication: A[i, k] = sum_j X[i, j] * Y[j, k].
# X is (2, 3) and Y is (3, 4), so A is (2, 4).
X = torch.arange(6).reshape(2, 3)
Y = torch.arange(12).reshape(3, 4)
A = torch.einsum("ij,jk->ik", X, Y)
# torch.mm is the dedicated equivalent; check that both give the same result.
torch.testing.assert_close(A, torch.mm(X, Y))
print(A)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Matrix times a column vector. y is stored as a (1, 3) row, so the einsum
# contracts over its second axis: A[i, k] = sum_j X[i, j] * y[k, j].
# This equals X @ y.T and has shape (3, 1).
X = torch.rand((3, 3))
y = torch.rand((1, 3))
A = torch.einsum("ij,kj->ik", X, y)
# Equivalent: torch.mm(X, torch.transpose(y, 0, 1)) or torch.mm(X, y.T).
torch.testing.assert_close(A, torch.mm(X, y.T))
print(A)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Outer product: A[i, j] = v[i] * t[j], shape (3, 3).
# Note (3,) is a one-element shape tuple; a bare (3) is just the int 3.
v = torch.rand((3,))
t = torch.rand((3,))
A = torch.einsum("i,j->ij", v, t)
# torch.outer is the dedicated equivalent; check that both agree.
torch.testing.assert_close(A, torch.outer(v, t))
print(A)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Dot product: a = sum_i v[i] * c[i], returned as a 0-dim tensor.
v = torch.rand((3,))
c = torch.rand((3,))
a = torch.einsum("i,i->", v, c)
# torch.dot is the dedicated equivalent; check that both agree.
torch.testing.assert_close(a, torch.dot(v, c))
print(a)  # e.g. tensor(0.5750); the value varies because v and c are random
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Element-wise (Hadamard) product: A[i, j] = X[i, j] * Y[i, j], shape (3, 2).
X = torch.rand((3, 2))
Y = torch.rand((3, 2))
A = torch.einsum("ij,ij->ij", X, Y)
# Equivalent: torch.mul(X, Y) or X * Y; check that they agree.
torch.testing.assert_close(A, torch.mul(X, Y))
print(A)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
X = torch.rand((2, 3))
# Row summation: a[i] = sum_j X[i, j], shape (2,).
a = torch.einsum("ij->i", X)
# torch.sum over dim=1 is the direct equivalent (`axis=` is only a NumPy-style alias).
torch.testing.assert_close(a, torch.sum(X, dim=1))
print(a)  # e.g. tensor([1.4088, 1.7803]); values vary because X is random
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
# Summation of every element: a = sum_ij X[i, j], returned as a 0-dim tensor.
X = torch.rand((2, 3))
a = torch.einsum("ij->", X)
# torch.sum with no dim is the direct equivalent; check that both agree.
torch.testing.assert_close(a, torch.sum(X))
print(a)  # e.g. tensor(3.5585); the value varies because X is random
Newer Older