Created
November 2, 2019 20:48
-
-
Save unixpickle/402e0646fc0ea0faf08904b60ddd4b69 to your computer and use it in GitHub Desktop.
3D model generator data source
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"compress/gzip" | |
"fmt" | |
"math" | |
"os" | |
"github.com/unixpickle/model3d" | |
) | |
// LoadCollider reads the model at the given path and | |
// returns a collider for it. | |
// | |
// If useCache is true, then extra files may bo read or | |
// created alongside the file to help improve performance | |
// when loading this model in the future. | |
func LoadCollider(path string, useCache bool) (model3d.Collider, error) { | |
cachePath := path + "_cache.stl.gz" | |
if useCache { | |
if triangles := ReadCache(cachePath); triangles != nil { | |
return model3d.GroupedTrianglesToCollider(triangles), nil | |
} | |
} | |
f, err := os.Open(path) | |
if err != nil { | |
return nil, err | |
} | |
defer f.Close() | |
triangles, err := model3d.ReadOFF(f) | |
if err != nil { | |
return nil, fmt.Errorf("failed to read %s: %s", path, err) | |
} | |
CenterAndScale(triangles) | |
model3d.GroupTriangles(triangles) | |
if useCache { | |
WriteCache(cachePath, triangles) | |
} | |
return model3d.GroupedTrianglesToCollider(triangles), nil | |
} | |
func CenterAndScale(triangles []*model3d.Triangle) { | |
min := triangles[0].Min() | |
max := triangles[0].Max() | |
for _, t := range triangles { | |
min = min.Min(t.Min()) | |
max = max.Max(t.Max()) | |
} | |
offset := min.Scale(-1) | |
sizes := max.Add(offset) | |
scale := 1 / math.Max(math.Max(sizes.X, sizes.Y), sizes.Z) | |
sizes = sizes.Scale(scale) | |
centerOffset := model3d.Coord3D{X: 1, Y: 1, Z: 1}.Add(sizes.Scale(-scale)).Scale(0.5) | |
for _, t := range triangles { | |
for i, p := range t { | |
t[i] = p.Add(offset).Scale(scale).Add(centerOffset) | |
} | |
} | |
} | |
func WriteCache(path string, t []*model3d.Triangle) { | |
tmpPath := path + ".tmp" | |
w, err := os.Create(tmpPath) | |
if err != nil { | |
return | |
} | |
zipWriter := gzip.NewWriter(w) | |
_, err = zipWriter.Write(model3d.EncodeSTL(t)) | |
if err == nil { | |
err = zipWriter.Close() | |
} else { | |
zipWriter.Close() | |
} | |
if err == nil { | |
err = w.Close() | |
} else { | |
w.Close() | |
} | |
if err != nil { | |
os.Remove(tmpPath) | |
return | |
} | |
os.Rename(tmpPath, path) | |
} | |
func ReadCache(path string) []*model3d.Triangle { | |
f, err := os.Open(path) | |
if err != nil { | |
return nil | |
} | |
defer f.Close() | |
zipReader, err := gzip.NewReader(f) | |
if err != nil { | |
return nil | |
} | |
defer zipReader.Close() | |
triangles, err := model3d.ReadSTL(zipReader) | |
if err != nil { | |
return nil | |
} | |
return triangles | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Command data_source generates data to be piped into a
// training program.
//
// Each line of output corresponds to a different model. A line
// is a space-separated sequence of samples. Each sample gives a
// point followed by the signed distance from that point to the
// model's surface (positive outside, negative inside). For
// example, consider this line:
//
//	0,1,2,0.1 0,1,2.2,-0.1
//
// Here, the point at (x=0, y=1, z=2) is 0.1 units outside the
// surface of the model, and the point at (x=0, y=1, z=2.2) is
// 0.1 units inside it.
package main | |
import ( | |
"flag" | |
"fmt" | |
"log" | |
"math/rand" | |
"os" | |
"path/filepath" | |
"runtime" | |
"strings" | |
"sync" | |
"time" | |
"github.com/unixpickle/essentials" | |
"github.com/unixpickle/model3d" | |
) | |
func main() { | |
rand.Seed(time.Now().UnixNano()) | |
var testing bool | |
var cacheGrouping bool | |
var samplesPerModel int | |
flag.Usage = func() { | |
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data directory>\n", os.Args[0]) | |
fmt.Fprintln(os.Stderr, "Flags:") | |
flag.PrintDefaults() | |
} | |
flag.BoolVar(&testing, "testing", false, "a flag to enable the test set") | |
flag.BoolVar(&cacheGrouping, "cache", false, "use cache files to speed up generation") | |
flag.IntVar(&samplesPerModel, "samples", 1000, "point samples per 3D model") | |
flag.Parse() | |
if len(flag.Args()) != 1 { | |
flag.Usage() | |
os.Exit(1) | |
} | |
dataDir := flag.Args()[0] | |
log.Printf("Locating model files in %s...", dataDir) | |
modelFiles := ModelFiles(dataDir, testing) | |
log.Printf("Found %d model files.", len(modelFiles)) | |
for i := 0; true; i++ { | |
log.Printf("Running epoch %d...", i) | |
rand.Shuffle(len(modelFiles), func(i, j int) { | |
modelFiles[i], modelFiles[j] = modelFiles[j], modelFiles[i] | |
}) | |
fileChan := make(chan string, len(modelFiles)) | |
for _, f := range modelFiles { | |
fileChan <- f | |
} | |
close(fileChan) | |
var wg sync.WaitGroup | |
var printLock sync.Mutex | |
for i := 0; i < runtime.GOMAXPROCS(8); i++ { | |
wg.Add(1) | |
go func() { | |
defer wg.Done() | |
for file := range fileChan { | |
collider, err := LoadCollider(file, cacheGrouping) | |
if err != nil { | |
log.Println(err) | |
continue | |
} | |
str := ColliderToString(collider, samplesPerModel) | |
printLock.Lock() | |
fmt.Println(str) | |
printLock.Unlock() | |
} | |
}() | |
} | |
wg.Wait() | |
} | |
} | |
func ModelFiles(dir string, testing bool) []string { | |
var result []string | |
essentials.Must(filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { | |
if filepath.Base(path) == "test" && !testing { | |
return filepath.SkipDir | |
} else if filepath.Base(path) == "train" && testing { | |
return filepath.SkipDir | |
} | |
if !info.IsDir() && filepath.Ext(path) == ".off" { | |
result = append(result, path) | |
} | |
return nil | |
})) | |
return result | |
} | |
func ColliderToString(collider model3d.Collider, samples int) string { | |
var sphereStrs []string | |
for i := 0; i < samples; i++ { | |
c := model3d.Coord3D{X: rand.Float64(), Y: rand.Float64(), Z: rand.Float64()} | |
distance := DistToSurface(collider, c) | |
ray := &model3d.Ray{ | |
Origin: c, | |
Direction: model3d.Coord3D{ | |
X: rand.NormFloat64(), | |
Y: rand.NormFloat64(), | |
Z: rand.NormFloat64(), | |
}, | |
} | |
if collider.RayCollisions(ray)%2 == 1 { | |
distance *= -1 | |
} | |
sphereStrs = append(sphereStrs, fmt.Sprintf("%f,%f,%f,%f", c.X, c.Y, c.Z, distance)) | |
} | |
return strings.Join(sphereStrs, " ") | |
} | |
func DistToSurface(c model3d.Collider, p model3d.Coord3D) float64 { | |
if !c.SphereCollision(p, 1.0) { | |
// Upper bound is 1, since that's how large we | |
// scale models to. | |
return 1.0 | |
} | |
max := 1.0 | |
min := 0.0 | |
for i := 0; i < 10; i++ { | |
x := (max + min) / 2 | |
if c.SphereCollision(p, x) { | |
max = x | |
} else { | |
min = x | |
} | |
} | |
return (min + max) / 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment