Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Last active December 8, 2017 22:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unixpickle/2bb94237eccebfc9490417d0b54c7419 to your computer and use it in GitHub Desktop.
Save unixpickle/2bb94237eccebfc9490417d0b54c7419 to your computer and use it in GitHub Desktop.
Match images to ones in ImageNet
// Find miniImageNet images in the original dataset.
package main
import (
"fmt"
"image"
"image/jpeg"
"io"
"log"
"os"
"path/filepath"
"runtime"
"sync"
"github.com/unixpickle/num-analysis/linalg"
"github.com/unixpickle/resize"
)
// Directory locations used by this tool. Both are absolute, machine-specific
// paths and need to be changed to run the tool on another machine.
const (
	// ImageNetDir holds the original ImageNet training set, with one
	// subdirectory per class ID containing that class's ".JPEG" files.
	ImageNetDir = "/media/NAS_SHARED/imagenet/ILSVRC2012_img_train_t1_t2"

	// ImagesDir holds the miniImageNet ".jpg" images to look up; the first
	// 9 characters of each file name are used as its class ID.
	ImagesDir = "/home/alex/images"
)
func main() {
hashes := map[string][]*HashedImage{}
for name := range ReadDir(ImagesDir) {
if filepath.Ext(name) != ".jpg" {
continue
}
className := name[:9]
if _, ok := hashes[className]; !ok {
log.Println("Hashing ImageNet class", className)
hashes[className] = HashClass(className)
}
hash := HashImage(ReadImage(filepath.Join(ImagesDir, name)))
bestDot := 0.0
bestName := "<unknown>"
for _, option := range hashes[className] {
dot := option.Hash.Dot(hash)
if dot > bestDot {
bestDot = dot
bestName = option.Name
}
}
if bestDot < 0.9 {
log.Println("WARNING: low correlation for:", name, "-",
bestName, bestDot)
}
fmt.Printf("%s/%s (correlation=%f)\n", className, bestName, bestDot)
}
}
// HashImage downsamples img to 64x64 with bilinear resampling and returns the
// pixels' red, green and blue values (the 16-bit values from Color.RGBA) as a
// vector divided by (its magnitude + 1e-5). The result is roughly unit length,
// so the dot product of two hashes compares how similar two images look.
//
// Samples are stored column by column (x outer, y inner), three per pixel.
func HashImage(img image.Image) linalg.Vector {
	const side = 64
	small := resize.Resize(side, side, img, resize.Bilinear)

	hash := make(linalg.Vector, side*side*3)
	idx := 0
	for col := 0; col < side; col++ {
		for row := 0; row < side; row++ {
			red, green, blue, _ := small.At(col, row).RGBA()
			hash[idx] = float64(red)
			hash[idx+1] = float64(green)
			hash[idx+2] = float64(blue)
			idx += 3
		}
	}

	// The epsilon keeps an all-black image from dividing by zero.
	norm := hash.Mag() + 1e-5
	return hash.Scale(1 / norm)
}
// HashClass hashes every ".JPEG" file in ImageNetDir/name using HashImage and
// returns one HashedImage per file (Name is the bare file name). The work is
// spread over runtime.GOMAXPROCS(0) goroutines, so the order of the returned
// slice depends on scheduling. It returns an empty, non-nil slice when the
// directory contains no ".JPEG" files.
func HashClass(name string) []*HashedImage {
	classDir := filepath.Join(ImageNetDir, name)
	files := ReadDir(classDir)

	var (
		mu      sync.Mutex
		wg      sync.WaitGroup
		results = []*HashedImage{}
	)

	// Each worker pulls names off the shared listing channel until it closes.
	worker := func() {
		defer wg.Done()
		for fileName := range files {
			if filepath.Ext(fileName) == ".JPEG" {
				h := HashImage(ReadImage(filepath.Join(classDir, fileName)))
				mu.Lock()
				results = append(results, &HashedImage{Name: fileName, Hash: h})
				mu.Unlock()
			}
		}
	}

	numWorkers := runtime.GOMAXPROCS(0)
	wg.Add(numWorkers)
	for w := 0; w < numWorkers; w++ {
		go worker()
	}
	wg.Wait()
	return results
}
// ReadImage opens the file at path and decodes it as a JPEG.
//
// It panics if the file cannot be opened (with the *os.PathError, which
// already names the file) or if decoding fails (with a string of the form
// "path: decode error").
func ReadImage(path string) image.Image {
	file, openErr := os.Open(path)
	if openErr != nil {
		panic(openErr)
	}
	defer file.Close()

	img, decodeErr := jpeg.Decode(file)
	if decodeErr == nil {
		return img
	}
	panic(path + ": " + decodeErr.Error())
}
// ReadDir streams the entry names of the directory at path over the returned
// channel. A background goroutine reads names in batches of 100 and closes the
// channel once all of them have been sent. Names arrive in whatever order
// Readdirnames yields them.
//
// Errors opening or reading the directory cause a panic in that background
// goroutine. The caller should drain the channel, or the goroutine stays
// blocked on its send.
func ReadDir(path string) <-chan string {
	names := make(chan string, 1)
	go func() {
		defer close(names)

		dir, err := os.Open(path)
		if err != nil {
			panic(err)
		}
		defer dir.Close()

		for {
			batch, err := dir.Readdirnames(100)
			switch {
			case err == io.EOF:
				return
			case err != nil:
				panic(err)
			}
			for _, entry := range batch {
				names <- entry
			}
		}
	}()
	return names
}
// HashedImage pairs an ImageNet training file with the signature that
// HashImage computed for it; HashClass returns a slice of these.
type HashedImage struct {
// Name is the bare file name within the class directory (not a full path).
Name string
// Hash is the vector returned by HashImage for that file, scaled to roughly
// unit length so dot products between hashes are comparable.
Hash linalg.Vector
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment