This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Returns the current CPU time in nanoseconds. | |
/// | |
/// # Panics | |
/// | |
/// Panics if the system call to get the CPU time fails. | |
fn thread_time_ns() -> u128 { | |
let mut time = libc::timespec { tv_sec: 0, tv_nsec: 0 }; | |
unsafe { | |
assert_ne!( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{borrow::Borrow, hash::Hash}; | |
use rustc_hash::FxHashMap; | |
pub struct Trie<K, V> { | |
root: TrieNode<K, V>, | |
} | |
#[derive(Debug, Clone)] | |
struct TrieNode<K, V> { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Transforms the moves into a binary tree to group them, and then selects a | |
/// candidate to explore using Thompson Sampling. | |
/// | |
/// This move grouping takes [1] to its logical conclusion by forcing the policy | |
/// to always choose one of two groups. | |
/// | |
/// [1]: https://webdocs.cs.ualberta.ca/~mmueller/ps/GVanEyck-MoveGroups-Final.pdf | |
fn thompson_sampling_binary_tree_policy(tree: &[MctsNode], rng: &mut impl Rng) -> usize { | |
/// The maximum branching factor. | |
const B: usize = 128; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub(crate) fn partition_in_place<T, P>(slice: &mut [T], mut predicate: P) -> usize | |
where | |
P: FnMut(&T) -> bool, | |
{ | |
let mut head = 0; | |
let mut last = slice.len(); | |
while head < last { | |
if predicate(&slice[head]) { | |
head += 1; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rand::Rng; | |
/// Chooses an element from an iterator with weights. | |
/// | |
/// See <https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao>. | |
#[expect(dead_code, reason = "Maybe useful in the future.")] | |
pub(crate) fn choose_weighted<I, R, F>(mut iter: I, rng: &mut R, mut weight: F) -> Option<I::Item> | |
where | |
R: Rng, | |
I: Iterator, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
def _round_ste(x: torch.Tensor) -> torch.Tensor: | |
return x + (torch.round(x) - x).detach() | |
class FSQ(nn.Module): | |
r""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.nn import functional as fn | |
class LearnedGaussianNLL(nn.Module): | |
def __init__(self, log_var: torch.Tensor): | |
super().__init__() | |
self.log_var = nn.Parameter(log_var, requires_grad=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
def haar_wavelet_packet(input: torch.Tensor, dim: int) -> torch.Tensor: | |
return _haar_wavelet_packet(input.movedim(dim, 0)).movedim(0, dim) | |
def _haar_wavelet_packet( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
def random_projection_matrix(in_features: int, out_features: int) -> torch.Tensor: | |
r""" | |
See <https://en.wikipedia.org/wiki/Random_projection#More_computationally_efficient_random_projections>. | |
""" | |
p = torch.empty(in_features, out_features).bernoulli_() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.distributions import Categorical | |
class TimeJitter(nn.Module): | |
r"""Randomly replaces latent vectors with its neighbors to prevent co-adaptation. | |
See <https://arxiv.org/pdf/1901.08810#page=9>. | |
""" |
NewerOlder