helloworld/
│
├── bin/
│
├── docs/
│ ├── hello.md
│ └── world.md
# config.yml
# YAML supports comments (JSON does not)
boolean:
- TRUE   # true and True are also accepted
- FALSE  # false and False are also accepted
float:
- 3.14
- 6.8523015e+5  # scientific notation is supported
int:
- 123
"""Information Retrieval metrics

Useful Resources:
http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt
http://www.nii.ac.jp/TechReports/05-014E.pdf
http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf
Learning to Rank for Information Retrieval (Tie-Yan Liu)
"""
import numpy as np
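The preview above stops after the imports, so the metrics themselves are not shown. As a rough illustration of the kind of function such a module usually holds, here is a minimal sketch of three common ranking metrics. The function names and the choice of metrics are assumptions, not the gist's actual contents, and the sketch uses only the standard library rather than NumPy so it runs anywhere.

```python
import math


def precision_at_k(relevance, k):
    """Fraction of the top-k results that are relevant (relevance is 0/1)."""
    return sum(1 for r in relevance[:k] if r) / k


def reciprocal_rank(relevance):
    """1 / rank of the first relevant result, or 0.0 if none is relevant."""
    for rank, r in enumerate(relevance, start=1):
        if r:
            return 1.0 / rank
    return 0.0


def dcg_at_k(relevance, k):
    """Discounted cumulative gain with the log2(rank + 1) discount."""
    return sum(
        rel / math.log2(rank + 1)
        for rank, rel in enumerate(relevance[:k], start=1)
    )


def ndcg_at_k(relevance, k):
    """DCG normalized by the DCG of the ideal (descending) ordering."""
    ideal = dcg_at_k(sorted(relevance, reverse=True), k)
    return dcg_at_k(relevance, k) / ideal if ideal > 0 else 0.0
```

Each function takes the relevance labels of one ranked result list, in ranked order. Averaging `reciprocal_rank` over many queries gives mean reciprocal rank.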
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# download and transform train dataset
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data',
                                                          download=True,
count     30522.000000
mean          6.584955
std           2.483049
min           1.000000
25%           5.000000
50%           6.000000
75%           8.000000
max          18.000000
median        6
predictions = model(inputs)                # Forward pass
loss = loss_function(predictions, labels)  # Compute loss function
loss.backward()                            # Backward pass
optimizer.step()                           # Optimizer step
predictions = model(inputs)                # Forward pass with new parameters
# Here is a simple gist for training a model using gradient accumulation.
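The preview ends before the accumulation loop itself appears. The idea can be sketched without PyTorch: sum the gradients of several micro-batches, each scaled by the number of micro-batches, and apply a single update at the end. The model (`y = w * x`), the helper names, and the data below are hypothetical and chosen only for illustration. The comments note which PyTorch calls each step roughly corresponds to.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error of y = w * x over `batch`."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def train_step_accumulated(w, micro_batches, lr):
    """One parameter update built from several micro-batch gradients."""
    accumulated = 0.0  # plays the role of the parameter's .grad buffer
    for batch in micro_batches:
        # Scaling by the number of micro-batches keeps the sum an average,
        # roughly what (loss / n).backward() does in PyTorch.
        accumulated += grad_mse(w, batch) / len(micro_batches)
    # A single update after all micro-batches, roughly
    # optimizer.step() followed by optimizer.zero_grad().
    return w - lr * accumulated


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [data[:2], data[2:]]  # equal sizes, so the averages match

w_full = 0.0 - 0.01 * grad_mse(0.0, data)
w_accum = train_step_accumulated(0.0, micro_batches, lr=0.01)
print(w_full, w_accum)
```

With equally sized micro-batches the accumulated update matches the full-batch update, which is the point of the technique: a large effective batch size without holding the whole batch in memory at once.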
#deb cdrom:[Ubuntu 20.04 LTS _Focal Fossa_ - Release amd64 (20200423)]/ focal main restricted
# See http://help.ubuntu.com/community/UpgradeNotes for how to upgrade to
# newer versions of the distribution.
deb http://us.archive.ubuntu.com/ubuntu/ focal main restricted
# deb-src http://us.archive.ubuntu.com/ubuntu/ focal main restricted
## Major bug fix updates produced after the final release of the
## distribution.
deb http://us.archive.ubuntu.com/ubuntu/ focal-updates main restricted
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
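The function body is cut off after the safety check. The following is a minimal sketch of how the two filters described in the docstring could work, written against plain Python lists instead of tensors so that it runs without PyTorch. The name `filter_logits` and every implementation detail are assumptions, not the gist's code.

```python
import math


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-math.inf):
    """Return a copy of `logits` with filtered entries set to `filter_value`."""
    out = list(logits)
    top_k = min(top_k, len(out))  # same safety check as in the preview
    if top_k > 0:
        # Value of the k-th largest logit; anything strictly below it is removed.
        # Ties with the k-th value are kept, so more than k tokens may survive.
        kth = sorted(out, reverse=True)[top_k - 1]
        out = [x if x >= kth else filter_value for x in out]
    if top_p > 0.0:
        # Visit tokens from most to least likely, tracking softmax mass.
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        largest = out[order[0]]
        exps = [math.exp(out[i] - largest) for i in order]
        total = sum(exps)
        cumulative = 0.0
        for i, e in zip(order, exps):
            # Remove a token only if the mass before it already reaches top_p,
            # so the token that crosses the threshold is still kept.
            if cumulative >= top_p:
                out[i] = filter_value
            cumulative += e / total
    return out
```

Sampling would then apply a softmax to the returned logits. Entries set to negative infinity receive zero probability, so only the surviving tokens can be drawn.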
# generate_from_lm.py
"""
Load a trained language model and generate text
Example usage:
PYTHONPATH=. python generate_from_lm.py \
--init="Although the food" --tau=0.5 \
--sample_method=gumbel --g_eps=1e-5 \
--load_model='checkpoints/lm/mlstm/hotel/batch_size_64/lm_e9_2.93.pt' \