import numpy as np
from dataclasses import dataclass
from functools import partial
from typing import Callable, Any

__all__ = ['logpdf', 'sample']

def joker_eye_logpdf(mean: np.ndarray, inv_std: np.ndarray, x: np.ndarray) -> np.ndarray:
    z = (x - mean) @ inv_std
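The preview stops right after whitening the residual. Assuming `inv_std` is the inverse transpose of a Cholesky factor of the covariance (so that `z` is standard normal when `x` is drawn from the Gaussian), the log-density follows from the whitened residual plus a log-determinant term. A minimal sketch under that assumption; `gaussian_logpdf_whitened` is a hypothetical name and not necessarily how `joker_eye_logpdf` continues:

```python
import numpy as np


def gaussian_logpdf_whitened(mean: np.ndarray, inv_std: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Log-density of N(mean, Sigma) for each row of x.

    Assumes inv_std = inv(L).T where Sigma = L @ L.T, so that
    z = (x - mean) @ inv_std is standard normal.
    """
    x = np.atleast_2d(x)
    d = mean.shape[-1]
    # Whitened residuals, one row per point
    z = (x - mean) @ inv_std
    # log|det inv_std| equals -0.5 * log det Sigma
    _, logabsdet = np.linalg.slogdet(inv_std)
    return -0.5 * np.sum(z**2, axis=-1) - 0.5 * d * np.log(2.0 * np.pi) + logabsdet
```

Passing the whitening matrix instead of the covariance means no matrix solve is needed per evaluation; the factorization cost is paid once when `inv_std` is built.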
using Metal
using LinearAlgebra
using LinearAlgebra: cholesky
using Metal: device, global_queue, PrivateStorage, SharedStorage
using Metal.MPS: MPSMatrix, MPSMatrixDecompositionStatus, MPSMatrixDecompositionCholesky, MPSMatrixCopy, MPSMatrixCopyDescriptor, MtlFloat, MPSCommandBuffer, encode!, commit!, commitAndContinue!, wait_completed
using Metal.MPS: checknonsingular, checkpositivedefinite

Metal.ObjectiveC.@autoreleasepool function LinearAlgebra.cholesky(A::MtlMatrix{T}, check::Bool = true, lower::Bool = true) where {T<:MtlFloat}
    M, N = size(A)
    @assert M == N
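The function above hands the factorization to MPS on the GPU. A CPU reference is handy for checking such a routine; the sketch below is a plain NumPy Cholesky-Banachiewicz factorization (hypothetical `cholesky_reference`, not part of Metal.jl) that mirrors the square-matrix assertion and the `lower` flag, and raises when the matrix is not positive definite:

```python
import numpy as np


def cholesky_reference(A, lower=True):
    """Cholesky-Banachiewicz factorization of a symmetric positive definite matrix."""
    A = np.asarray(A, dtype=float)
    m, n = A.shape
    assert m == n, "matrix must be square"
    L = np.zeros_like(A)
    for j in range(n):
        # Diagonal entry: what is left of A[j, j] after the earlier columns
        s = A[j, j] - np.dot(L[j, :j], L[j, :j])
        if s <= 0.0:
            raise np.linalg.LinAlgError(f"leading minor of order {j + 1} is not positive definite")
        L[j, j] = np.sqrt(s)
        # Entries below the diagonal in column j
        L[j + 1:, j] = (A[j + 1:, j] - L[j + 1:, :j] @ L[j, :j]) / L[j, j]
    # Return L (A = L L^T) or its transpose U (A = U^T U)
    return L if lower else L.T
```

Comparing the GPU factor against this (or `np.linalg.cholesky`) on a random SPD matrix, within a tolerance suited to single precision, is a quick sanity test.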
function getHeader() {
return `
<!doctype html><html>
<head>
<!--[if gte mso 9]><xml><o:OfficeDocumentSettings><o:AllowPNG/><o:PixelsPerInch>96</o:PixelsPerInch></o:OfficeDocumentSettings></xml><![endif]-->
<style>
body {
background-color: #fff
}
.gse_alrt_title {
using SpecialFunctions, Statistics, LinearAlgebra, Random

abstract type AbstractIntegralWeight end
abstract type AbstractTruncatedGaussian <: AbstractIntegralWeight end

struct TruncatedGaussian <: AbstractTruncatedGaussian
    a::Float64
    b::Float64
end
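Assuming `TruncatedGaussian(a, b)` denotes the standard normal density restricted to `[a, b]` (the field names suggest this, but the preview does not say), its normalizing constant and first two moments have closed forms in terms of `erf`. A sketch in Python; the function names are hypothetical:

```python
import math


def std_normal_pdf(x: float) -> float:
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(x: float) -> float:
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def _x_pdf(x: float) -> float:
    # x * pdf(x) -> 0 as |x| -> inf; avoids inf * 0 = nan at infinite bounds
    return 0.0 if math.isinf(x) else x * std_normal_pdf(x)


def truncated_std_normal_moments(a: float, b: float):
    """Normalizer Z, mean and variance of N(0, 1) truncated to [a, b]."""
    Z = std_normal_cdf(b) - std_normal_cdf(a)
    mean = (std_normal_pdf(a) - std_normal_pdf(b)) / Z
    var = 1.0 + (_x_pdf(a) - _x_pdf(b)) / Z - mean**2
    return Z, mean, var
```

For `[0, inf)` this gives the half-normal values `Z = 1/2`, mean `sqrt(2/pi)` and variance `1 - 2/pi`.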
""" | |
This file is intended to create a Gauss quadrature rule with weight `exp(-x)` that includes point `x_1=0`. | |
Frankly, I have no idea if it works | |
============= | |
Copyright 2024, Daniel Sharp | |
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: |
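A quadrature rule for weight `exp(-x)` on `[0, inf)` with a node pinned at 0 is a Gauss-Radau rule. One standard construction (Golub's modified Jacobi matrix, not necessarily what this file does) borders the Laguerre Jacobi matrix with one extra row and column and chooses the new diagonal entry so that 0 becomes an eigenvalue; the nodes are the eigenvalues and the weights are the squared first eigenvector components. A sketch with hypothetical function names, using the monic Laguerre recurrence `alpha_k = 2k + 1`, `beta_k = k^2`:

```python
import math
import numpy as np


def gauss_radau_laguerre(n: int, fixed: float = 0.0):
    """(n + 1)-point Gauss-Radau rule for weight exp(-x) on [0, inf) with a node at `fixed`."""
    # Monic Laguerre recurrence: p_{k+1} = (x - alpha_k) p_k - beta_k p_{k-1}
    alpha = 2.0 * np.arange(n) + 1.0              # alpha_0 .. alpha_{n-1}
    sqrt_beta = np.arange(1, n + 1, dtype=float)  # sqrt(beta_k) = k for k = 1 .. n
    # n x n Jacobi matrix of the ordinary Gauss-Laguerre rule
    J = np.diag(alpha) + np.diag(sqrt_beta[:-1], 1) + np.diag(sqrt_beta[:-1], -1)
    # Golub: solve (J - fixed*I) delta = beta_n e_n, then alpha_n' = fixed + delta_n
    rhs = np.zeros(n)
    rhs[-1] = sqrt_beta[-1] ** 2
    delta = np.linalg.solve(J - fixed * np.eye(n), rhs)
    # Bordered (n + 1) x (n + 1) matrix whose spectrum contains `fixed`
    Jr = np.zeros((n + 1, n + 1))
    Jr[:n, :n] = J
    Jr[n - 1, n] = Jr[n, n - 1] = sqrt_beta[-1]
    Jr[n, n] = fixed + delta[-1]
    nodes, vecs = np.linalg.eigh(Jr)
    # Weights: mu_0 * (first eigenvector component)^2, with mu_0 = int exp(-x) dx = 1
    weights = vecs[0, :] ** 2
    return nodes, weights


def max_moment_error(nodes, weights, max_degree: int) -> float:
    """Largest relative error against the exact moments int x^k exp(-x) dx = k!."""
    return max(
        abs(np.sum(weights * nodes**k) - math.factorial(k)) / math.factorial(k)
        for k in range(max_degree + 1)
    )
```

With `n` free nodes plus the fixed one, the rule should integrate polynomials up to degree `2n` exactly against `exp(-x)`, which `max_moment_error` checks using the moments `k!`. For `n = 1` the construction gives nodes `0, 2` with weights `1/2, 1/2`.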
using Distributions, SpecialFunctions, MParT, GLMakie

# Taken from StatsFuns
invsqrt2 = Float64(1 / sqrt(big(2)))
normcdf(z) = (erf(z * invsqrt2) + 1) / 2

function sample_banana(n_samples)
    z = randn(2, n_samples)
    [z[1, :]'; ((z[2, :]' .+ (z[1, :] .^ 2)') .- 1) / 1.75]
end
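`sample_banana` pushes a standard normal `z` through the triangular map `x1 = z1`, `x2 = (z2 + z1^2 - 1) / 1.75`. Because the map inverts in closed form (`z2 = 1.75 x2 - x1^2 + 1`, Jacobian determinant `1.75`), the exact log-density follows from the change-of-variables formula, which is a useful reference when fitting a transport map to these samples. A NumPy sketch of the same sampler plus that density (function names hypothetical):

```python
import numpy as np


def sample_banana(n_samples: int, rng=None) -> np.ndarray:
    """Draw samples of shape (2, n_samples), mirroring the Julia `sample_banana`."""
    rng = np.random.default_rng() if rng is None else rng
    z = rng.standard_normal((2, n_samples))
    return np.vstack([z[0], (z[1] + z[0] ** 2 - 1.0) / 1.75])


def banana_logpdf(x: np.ndarray) -> np.ndarray:
    """Exact log-density via change of variables; x has shape (2, ...)."""
    x1, x2 = x[0], x[1]
    # Inverse map back to the reference standard normal
    z1 = x1
    z2 = 1.75 * x2 - x1**2 + 1.0
    # log N(z1) + log N(z2) + log|det dz/dx|; the Jacobian is triangular with diagonal (1, 1.75)
    return -np.log(2.0 * np.pi) - 0.5 * (z1**2 + z2**2) + np.log(1.75)
```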
import mpart as mt
import numpy as np
import pickle
from matplotlib import pyplot as plt

# Class to represent a GMM
# For some reason, there's no good way to evaluate the logpdf and gradlogpdf of a GMM in Python
class GMM:
    # Initialize the GMM
    def __init__(self, weights, centers, vars):
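The `GMM` class exists to provide the log-density and its gradient. Assuming a one-dimensional mixture in which `vars` holds per-component variances (an assumption; the preview does not show how they are used), both quantities can be computed stably with the log-sum-exp trick, and the gradient is the responsibility-weighted sum of each component's score. A sketch with a hypothetical function name:

```python
import numpy as np


def gmm_logpdf_and_grad(x, weights, centers, variances):
    """Log-density and its derivative for a 1D Gaussian mixture, evaluated at each x."""
    x = np.asarray(x, dtype=float)[..., None]  # trailing axis indexes components
    w = np.asarray(weights, dtype=float)
    mu = np.asarray(centers, dtype=float)
    var = np.asarray(variances, dtype=float)
    # log w_k + log N(x; mu_k, var_k) for every point and component
    log_comp = np.log(w) - 0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mu) ** 2 / var
    # Log-sum-exp over components for numerical stability
    m = log_comp.max(axis=-1, keepdims=True)
    logpdf = m + np.log(np.exp(log_comp - m).sum(axis=-1, keepdims=True))
    # Responsibilities r_k(x) = w_k N_k(x) / p(x); they sum to 1 over k
    resp = np.exp(log_comp - logpdf)
    # d/dx log p(x) = sum_k r_k(x) * (-(x - mu_k) / var_k)
    grad = np.sum(resp * (-(x - mu) / var), axis=-1)
    return logpdf[..., 0], grad
```

Subtracting the per-point maximum before exponentiating keeps the sum finite even far from every center, where each component density underflows on its own.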
function hermite_gk22(n; normalize = true)
#*****************************************************************************80
#
## HERMITE_GK22_SET sets a Genz-Keister 22 Hermite rule.
#
# Discussion:
#
# The integral:
#
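Hermite rules in this style of code are commonly stated for the weight `exp(-x^2)`. Assuming the `normalize = true` option converts a rule to the standard normal density (the preview does not confirm this), the conversion is the substitution `x = sqrt(2) t`, which rescales nodes by `sqrt(2)` and weights by `1/sqrt(pi)`. The sketch below demonstrates it on an ordinary Gauss-Hermite rule from `numpy.polynomial.hermite.hermgauss`, not on the GK22 nodes; `normalize_hermite_rule` is a hypothetical name:

```python
import numpy as np


def normalize_hermite_rule(nodes: np.ndarray, weights: np.ndarray):
    """Convert a rule for weight exp(-t^2) into one for the standard normal pdf.

    Substituting x = sqrt(2) t gives
    int f(x) phi(x) dx = (1/sqrt(pi)) int f(sqrt(2) t) exp(-t^2) dt.
    """
    return np.sqrt(2.0) * nodes, weights / np.sqrt(np.pi)


# Stand-in rule: ordinary 10-point Gauss-Hermite, not Genz-Keister 22
t, w = np.polynomial.hermite.hermgauss(10)
x, w_normal = normalize_hermite_rule(t, w)
```

After normalization the weights sum to 1 and the rule reproduces the standard normal moments `E[x^2] = 1` and `E[x^4] = 3`.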
# This belongs in JULIA_DEPOT_PATH/config/startup.jl (by default ~/.julia/config/startup.jl)
# You usually have to create this folder the first time
# Type `julia --help` for more information on startup files
# Removes all files precompiled for a certain package, forcing it to recompile the next time it is loaded
function removePrecompile(pkg::String)
    julia_path = DEPOT_PATH
    if DEPOT_PATH isa AbstractVector
        julia_path = DEPOT_PATH[findmin(length, DEPOT_PATH)[2]]
    end
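The preview ends after choosing the shortest entry of `DEPOT_PATH`. Assuming the rest of the function deletes the package's cache directories under `<depot>/compiled/v<major>.<minor>/<pkg>` (Julia's usual layout for precompiled files), the same steps look like this in Python; `remove_precompile` and `make_fake_depot` are hypothetical helpers, and the fake depot lets you try the logic without touching a real installation:

```python
import os
import shutil
import tempfile


def remove_precompile(pkg: str, depot_paths) -> list:
    """Delete <depot>/compiled/v*/<pkg> in the shortest depot path; return removed dirs."""
    # Same choice as findmin(length, DEPOT_PATH): first entry of minimal length
    depot = min(depot_paths, key=len)
    compiled = os.path.join(depot, "compiled")
    removed = []
    if not os.path.isdir(compiled):
        return removed
    # One subdirectory per Julia minor version, e.g. "v1.10"
    for version_dir in sorted(os.listdir(compiled)):
        target = os.path.join(compiled, version_dir, pkg)
        if os.path.isdir(target):
            shutil.rmtree(target)
            removed.append(target)
    return removed


def make_fake_depot() -> str:
    """Create a throwaway depot with caches for two packages and two Julia versions."""
    depot = tempfile.mkdtemp()
    for version in ("v1.9", "v1.10"):
        for name in ("Foo", "Bar"):
            os.makedirs(os.path.join(depot, "compiled", version, name))
    return depot
```

Removing the cache for every Julia version, not just the running one, matches the comment's goal of forcing a recompile; restricting it to the current `vX.Y` directory would be the more conservative choice.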
# Test methods of the TriangularMap object
using MParT, Distributions
using Random, Printf, LinearAlgebra

function Kolmogorov_Smirnov(x)
    dist = Normal()
    sorted_samps = sort(x[:])
    samps_cdf = cdf.(dist, sorted_samps)
    N = length(sorted_samps)
    samps_ecdf = (1:N)/N
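The preview computes the sorted samples, their normal CDF values, and the empirical CDF `(1:N)/N`. The two-sided Kolmogorov-Smirnov statistic compares the CDF with the empirical CDF on both sides of each jump, `D = max_i max(i/N - F(x_(i)), F(x_(i)) - (i-1)/N)`. A NumPy sketch assuming a standard normal reference, as with `Normal()` above; `ks_statistic_normal` is a hypothetical name and not necessarily how the Julia function finishes:

```python
import math
import numpy as np


def ks_statistic_normal(x) -> float:
    """Two-sided Kolmogorov-Smirnov statistic of samples x against N(0, 1)."""
    xs = np.sort(np.ravel(np.asarray(x, dtype=float)))
    n = xs.size
    # Standard normal CDF at each sorted sample
    cdf = np.array([0.5 * (1.0 + math.erf(v / math.sqrt(2.0))) for v in xs])
    i = np.arange(1, n + 1)
    d_plus = np.max(i / n - cdf)         # ECDF just after each jump minus F
    d_minus = np.max(cdf - (i - 1) / n)  # F minus ECDF just before each jump
    return float(max(d_plus, d_minus))
```

Checking only `i/N - F` would miss deviations where the CDF overshoots the empirical CDF just before a jump, which is why both one-sided terms are kept.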