Last active
August 16, 2019 05:36
-
-
Save momvart/19cba7bcad5b282931b8f8b2c78af924 to your computer and use it in GitHub Desktop.
A set of utility classes to read MNIST IDX files. http://yann.lecun.com/exdb/mnist/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.Collections; | |
import java.util.List; | |
import java.util.stream.Collectors; | |
public class IDXData { | |
protected final ArrayList<Integer> dimensions; | |
protected final byte[][] rawData; | |
public IDXData(int[] dimensions, byte[][] rawData) { | |
this(Arrays.stream(dimensions).boxed().collect(Collectors.toList()), rawData); | |
} | |
public IDXData(List<Integer> dimensions, byte[][] rawData) { | |
this.dimensions = new ArrayList<>(dimensions); | |
this.dimensions.trimToSize(); | |
this.rawData = rawData; | |
} | |
public List<Integer> getDimensions() { | |
return Collections.unmodifiableList(dimensions); | |
} | |
public byte[][] getAllRawElements() { | |
return rawData; | |
} | |
/** | |
* A shortcut method for 1D data to improve the performance | |
*/ | |
public byte[] getRawElement(int i) { | |
if (dimensions.size() != 2) | |
throw new InconsistentDimensionsException(dimensions.size()); | |
return rawData[i]; | |
} | |
/** | |
* A shortcut method for 2D data to improve the performance | |
*/ | |
public byte[] getRawElement(int x, int y) { | |
if (dimensions.size() != 2) | |
throw new InconsistentDimensionsException(dimensions.size()); | |
return rawData[x * dimensions.get(1) + y]; | |
} | |
public byte[] getRawElement(int... location) { | |
if (location.length != dimensions.size()) | |
throw new InconsistentDimensionsException(dimensions.size()); | |
int index = 0; | |
for (int i = 0; i < dimensions.size(); i++) | |
index = index * dimensions.get(i) + location[i]; | |
return rawData[index]; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays; | |
/**
 * Element types that can appear in an IDX file, keyed by the type code
 * stored in the file header, along with the byte width of one element.
 */
public enum IDXDataType {
    UBYTE(0x08, 1, Byte.class),
    SBYTE(0x09, 1, Byte.class),
    SHORT(0x0B, 2, Short.class),
    INT(0x0C, 4, Integer.class),
    FLOAT(0x0D, 4, Float.class),
    DOUBLE(0x0E, 8, Double.class);

    /** Type code as it appears in the IDX header. */
    private final int value;

    /** Number of bytes occupied by a single element of this type. */
    private final int size;

    /** Boxed Java type associated with this IDX type (not read within this enum). */
    private final Class<? extends Number> javaClass;

    IDXDataType(int typeCode, int byteSize, Class<? extends Number> boxedType) {
        this.value = typeCode;
        this.size = byteSize;
        this.javaClass = boxedType;
    }

    /**
     * Finds the data type registered under the given header type code.
     *
     * @param value type code to look up
     * @return the matching constant, or {@code null} when no constant uses that code
     */
    public static IDXDataType getByTypeValue(int value) {
        for (IDXDataType candidate : values()) {
            if (candidate.value == value) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @return the number of bytes in one element of this type
     */
    public int getSize() {
        return size;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.Closeable; | |
import java.io.DataInputStream; | |
import java.io.IOException; | |
import java.io.InputStream; | |
import java.util.Arrays; | |
import java.util.stream.IntStream; | |
public class IDXReader implements Closeable { | |
private DataInputStream input; | |
private IDXDataType dataType; | |
private int[] dimensions; | |
public IDXReader(InputStream input) throws IOException { | |
this.input = new DataInputStream(input); | |
init(); | |
} | |
private void init() throws IOException { | |
readMagic(); | |
for (int i = 0; i < dimensions.length; i++) | |
dimensions[i] = input.readInt(); | |
} | |
private void readMagic() throws IOException { | |
//Skipping the first two zero bytes | |
input.readByte(); | |
input.readByte(); | |
dataType = IDXDataType.getByTypeValue(input.readUnsignedByte()); | |
dimensions = new int[input.readUnsignedByte()]; | |
} | |
public int[] getDimensions() { | |
return Arrays.copyOf(dimensions, dimensions.length); | |
} | |
public IDXData nextData(int dimensionsCount) throws IOException { | |
int readCount = IntStream.range(dimensions.length - dimensionsCount, dimensions.length) | |
.map(i -> dimensions[i]) | |
.reduce(1, (left, right) -> left * right); | |
byte[][] raw = new byte[readCount][dataType.getSize()]; | |
for (byte[] element : raw) | |
if (input.read(element) != element.length) | |
throw new IOException(); | |
return new IDXData(Arrays.copyOfRange(dimensions, dimensions.length - dimensionsCount, dimensions.length) | |
, raw); | |
} | |
@Override | |
public void close() throws IOException { | |
input.close(); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays; | |
import java.util.List; | |
import java.util.stream.IntStream; | |
public class IDXUnsignedByteData { | |
private static int convertToUnsignedByte(byte signed) { | |
return 0xff & signed; | |
} | |
private IDXData base; | |
public IDXUnsignedByteData(IDXData base) { | |
this.base = base; | |
} | |
public int[] getAllElements() { | |
return getAllElementsStream().toArray(); | |
} | |
public IntStream getAllElementsStream() { | |
return Arrays.stream(getAllRawElements()) | |
.mapToInt(bytes -> convertToUnsignedByte(bytes[0])); | |
} | |
/** | |
* A shortcut method for 1D data to improve the performance | |
*/ | |
public int getElement(int i) { | |
byte[] raw = getRawElement(i); | |
return convertToUnsignedByte(raw[0]); | |
} | |
/** | |
* A shortcut method for 2D data to improve the performance | |
*/ | |
public int getElement(int i, int j) { | |
byte[] raw = getRawElement(i, j); | |
return convertToUnsignedByte(raw[0]); | |
} | |
public int getElement(int... locations) { | |
byte[] raw = getRawElement(locations); | |
return convertToUnsignedByte(raw[0]); | |
} | |
public List<Integer> getDimensions() { | |
return base.getDimensions(); | |
} | |
public byte[][] getAllRawElements() { | |
return base.getAllRawElements(); | |
} | |
public byte[] getRawElement(int i) { | |
return base.getRawElement(i); | |
} | |
public byte[] getRawElement(int i, int j) { | |
return base.getRawElement(i, j); | |
} | |
public byte[] getRawElement(int... location) { | |
return base.getRawElement(location); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Thrown when an element is addressed with a number of coordinates that
 * does not match the number of dimensions of the data.
 */
public class InconsistentDimensionsException extends RuntimeException {

    /**
     * @param expected the number of coordinates that should have been supplied
     */
    public InconsistentDimensionsException(int expected) {
        super(describe(expected));
    }

    /** Builds the exception message for the given expected coordinate count. */
    private static String describe(int expected) {
        return "Coordinates must be in size of ".concat(Integer.toString(expected));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment