Skip to content

Instantly share code, notes, and snippets.

@momvart
Last active August 16, 2019 05:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save momvart/19cba7bcad5b282931b8f8b2c78af924 to your computer and use it in GitHub Desktop.
Save momvart/19cba7bcad5b282931b8f8b2c78af924 to your computer and use it in GitHub Desktop.
A set of utility classes to read MNIST IDX files. http://yann.lecun.com/exdb/mnist/
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
 * An in-memory block of IDX data: a list of dimension sizes plus the raw bytes
 * of every element, stored as one flat array in row-major order (the last
 * dimension varies fastest).
 *
 * <p>Each entry of the raw array is the byte representation of a single
 * element. Coordinates passed to the accessors are not range-checked per axis;
 * only the resulting flat index is checked by the array access itself.
 */
public class IDXData {
    protected final ArrayList<Integer> dimensions;
    protected final byte[][] rawData;

    /**
     * Creates a data block from primitive dimension sizes.
     *
     * @param dimensions size of each dimension, outermost first
     * @param rawData    flat, row-major array of raw elements; stored by reference, not copied
     */
    public IDXData(int[] dimensions, byte[][] rawData) {
        this(Arrays.stream(dimensions).boxed().collect(Collectors.toList()), rawData);
    }

    /**
     * Creates a data block from boxed dimension sizes.
     *
     * @param dimensions size of each dimension, outermost first; copied defensively
     * @param rawData    flat, row-major array of raw elements; stored by reference, not copied
     */
    public IDXData(List<Integer> dimensions, byte[][] rawData) {
        this.dimensions = new ArrayList<>(dimensions);
        this.dimensions.trimToSize();
        this.rawData = rawData;
    }

    /**
     * @return a read-only view of the dimension sizes, outermost first
     */
    public List<Integer> getDimensions() {
        return Collections.unmodifiableList(dimensions);
    }

    /**
     * @return the backing flat array of raw elements (not a copy)
     */
    public byte[][] getAllRawElements() {
        return rawData;
    }

    /**
     * A shortcut method for 1D data to improve the performance.
     *
     * @param i index of the element
     * @return the raw bytes of the element at {@code i}
     * @throws InconsistentDimensionsException if this data is not one-dimensional
     */
    public byte[] getRawElement(int i) {
        // A single coordinate is only meaningful for 1D data. The check used to
        // compare against 2 (copied from the 2D shortcut), which rejected every
        // 1D block and accepted 2D blocks with a flat index.
        if (dimensions.size() != 1)
            throw new InconsistentDimensionsException(dimensions.size());
        return rawData[i];
    }

    /**
     * A shortcut method for 2D data to improve the performance.
     *
     * @param x coordinate along the first dimension
     * @param y coordinate along the second dimension
     * @return the raw bytes of the element at {@code (x, y)}
     * @throws InconsistentDimensionsException if this data is not two-dimensional
     */
    public byte[] getRawElement(int x, int y) {
        if (dimensions.size() != 2)
            throw new InconsistentDimensionsException(dimensions.size());
        return rawData[x * dimensions.get(1) + y];
    }

    /**
     * General accessor for data of any dimensionality.
     *
     * @param location one coordinate per dimension, outermost first
     * @return the raw bytes of the element at {@code location}
     * @throws InconsistentDimensionsException if the number of coordinates
     *                                         differs from the number of dimensions
     */
    public byte[] getRawElement(int... location) {
        if (location.length != dimensions.size())
            throw new InconsistentDimensionsException(dimensions.size());
        // Horner-style accumulation of the row-major flat index.
        int index = 0;
        for (int i = 0; i < dimensions.size(); i++)
            index = index * dimensions.get(i) + location[i];
        return rawData[index];
    }
}
import java.util.Arrays;
/**
 * Element types that an IDX header can declare. Each constant carries the
 * type code stored in the header, the number of bytes one element occupies,
 * and the boxed Java class associated with it.
 */
public enum IDXDataType {
    UBYTE(0x08, 1, Byte.class),
    SBYTE(0x09, 1, Byte.class),
    SHORT(0x0B, 2, Short.class),
    INT(0x0C, 4, Integer.class),
    FLOAT(0x0D, 4, Float.class),
    DOUBLE(0x0E, 8, Double.class);

    /** Type code as it appears in the file header. */
    private final int value;
    /** Width of a single element in bytes. */
    private final int size;
    /** Boxed Java type associated with this element type. */
    private final Class<? extends Number> javaClass;

    IDXDataType(int value, int size, Class<? extends Number> javaClass) {
        this.value = value;
        this.size = size;
        this.javaClass = javaClass;
    }

    /**
     * Looks up the constant whose header type code equals {@code value}.
     *
     * @param value type code to search for
     * @return the matching constant, or {@code null} when no constant uses that code
     */
    public static IDXDataType getByTypeValue(int value) {
        for (IDXDataType candidate : values()) {
            if (candidate.value == value) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @return the number of bytes one element of this type occupies
     */
    public int getSize() {
        return size;
    }
}
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.stream.IntStream;
/**
 * Sequential reader for a stream in the IDX format.
 *
 * <p>The header (magic number and dimension sizes) is consumed by the
 * constructor; element data is then read block by block with
 * {@link #nextData(int)}.
 */
public class IDXReader implements Closeable {
    private final DataInputStream input;
    private IDXDataType dataType;
    private int[] dimensions;

    /**
     * Wraps {@code input} and immediately reads the IDX header from it.
     *
     * @param input stream positioned at the start of the IDX content
     * @throws IOException if the header cannot be read, declares an unknown
     *                     element type, or declares a negative dimension size
     */
    public IDXReader(InputStream input) throws IOException {
        this.input = new DataInputStream(input);
        init();
    }

    /** Reads the magic number followed by one 32-bit size per dimension. */
    private void init() throws IOException {
        readMagic();
        for (int i = 0; i < dimensions.length; i++) {
            int size = input.readInt();
            // A negative size can only come from a corrupt header and would
            // otherwise surface later as a NegativeArraySizeException.
            if (size < 0)
                throw new IOException("Invalid size " + size + " for dimension " + i);
            dimensions[i] = size;
        }
    }

    /** Reads the 4-byte magic number: two skipped bytes, the type code, the dimension count. */
    private void readMagic() throws IOException {
        //Skipping the first two zero bytes
        input.readByte();
        input.readByte();
        int typeCode = input.readUnsignedByte();
        dataType = IDXDataType.getByTypeValue(typeCode);
        // Fail here with a descriptive error instead of a NullPointerException
        // on the first nextData() call.
        if (dataType == null)
            throw new IOException("Unsupported IDX data type code: 0x" + Integer.toHexString(typeCode));
        dimensions = new int[input.readUnsignedByte()];
    }

    /**
     * @return a copy of the dimension sizes declared in the header, outermost first
     */
    public int[] getDimensions() {
        return Arrays.copyOf(dimensions, dimensions.length);
    }

    /**
     * Reads the next block of elements spanning the last {@code dimensionsCount}
     * dimensions. For example, with header dimensions {@code [n, r, c]},
     * {@code nextData(2)} reads {@code r * c} elements and {@code nextData(0)}
     * reads a single element.
     *
     * @param dimensionsCount how many trailing dimensions one block covers;
     *                        must be between 0 and the number of header dimensions
     * @return the block, carrying the trailing dimension sizes and the raw element bytes
     * @throws IllegalArgumentException if {@code dimensionsCount} is out of range
     * @throws ArithmeticException      if the number of elements in a block overflows an {@code int}
     * @throws IOException              if the stream ends before the block is complete
     *                                  or another I/O error occurs
     */
    public IDXData nextData(int dimensionsCount) throws IOException {
        if (dimensionsCount < 0 || dimensionsCount > dimensions.length)
            throw new IllegalArgumentException("dimensionsCount must be between 0 and "
                    + dimensions.length + " but was " + dimensionsCount);

        int[] blockDimensions = Arrays.copyOfRange(dimensions,
                dimensions.length - dimensionsCount, dimensions.length);

        int readCount = 1;
        for (int size : blockDimensions)
            readCount = Math.multiplyExact(readCount, size);

        byte[][] raw = new byte[readCount][dataType.getSize()];
        // readFully blocks until the whole element is available; a plain read()
        // may legally return fewer bytes, which is not an error and must not
        // leave the stream position in the middle of an element.
        for (byte[] element : raw)
            input.readFully(element);

        return new IDXData(blockDimensions, raw);
    }

    /** Closes the underlying stream. */
    @Override
    public void close() throws IOException {
        input.close();
    }
}
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
/**
 * View over an {@link IDXData} block that exposes each element as an unsigned
 * byte value (0..255), taken from the first byte of the element's raw
 * representation. Raw accessors are forwarded to the wrapped block unchanged.
 */
public class IDXUnsignedByteData {
    private IDXData base;

    /**
     * @param base the block whose elements are interpreted as unsigned bytes
     */
    public IDXUnsignedByteData(IDXData base) {
        this.base = base;
    }

    /** Widens a signed Java byte to its unsigned value in the range 0..255. */
    private static int unsigned(byte value) {
        return Byte.toUnsignedInt(value);
    }

    // ---- forwarded accessors of the wrapped block ----

    /**
     * @return the dimension sizes of the wrapped block
     */
    public List<Integer> getDimensions() {
        return base.getDimensions();
    }

    /**
     * @return the raw elements of the wrapped block
     */
    public byte[][] getAllRawElements() {
        return base.getAllRawElements();
    }

    /**
     * Raw access by a single index; forwarded to the wrapped block.
     */
    public byte[] getRawElement(int i) {
        return base.getRawElement(i);
    }

    /**
     * Raw access by two coordinates; forwarded to the wrapped block.
     */
    public byte[] getRawElement(int i, int j) {
        return base.getRawElement(i, j);
    }

    /**
     * Raw access by an arbitrary number of coordinates; forwarded to the wrapped block.
     */
    public byte[] getRawElement(int... location) {
        return base.getRawElement(location);
    }

    // ---- unsigned-byte views ----

    /**
     * @return a stream of all elements, in storage order, as unsigned byte values
     */
    public IntStream getAllElementsStream() {
        final byte[][] elements = getAllRawElements();
        return IntStream.range(0, elements.length)
                .map(position -> unsigned(elements[position][0]));
    }

    /**
     * @return all elements, in storage order, as unsigned byte values
     */
    public int[] getAllElements() {
        IntStream values = getAllElementsStream();
        return values.toArray();
    }

    /**
     * Single-index variant of the element accessor.
     *
     * @return the unsigned value of the element's first byte
     */
    public int getElement(int i) {
        return unsigned(getRawElement(i)[0]);
    }

    /**
     * Two-coordinate variant of the element accessor.
     *
     * @return the unsigned value of the element's first byte
     */
    public int getElement(int i, int j) {
        return unsigned(getRawElement(i, j)[0]);
    }

    /**
     * Element accessor for an arbitrary number of coordinates.
     *
     * @return the unsigned value of the element's first byte
     */
    public int getElement(int... locations) {
        return unsigned(getRawElement(locations)[0]);
    }
}
/**
 * Unchecked exception signalling that the number of coordinates supplied to an
 * accessor does not match the number of dimensions of the data.
 */
public class InconsistentDimensionsException extends RuntimeException {

    /** Leading part of the message; the expected coordinate count is appended to it. */
    private static final String MESSAGE_PREFIX = "Coordinates must be in size of ";

    /**
     * @param expected the number of coordinates that should have been supplied
     */
    public InconsistentDimensionsException(int expected) {
        super(MESSAGE_PREFIX.concat(Integer.toString(expected)));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment