Skip to content

Instantly share code, notes, and snippets.

@emillynge
Created August 16, 2018 09:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save emillynge/86369a51d92d2032ef53c9350f458c4e to your computer and use it in GitHub Desktop.
DataBuffer length issue v2
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
/**
 * Small reproduction harness that prints shape, offset and data-buffer
 * information for several ND4J views derived from one 10x10 matrix, so the
 * reported data-buffer length can be compared with the view being inspected.
 */
public class DataBufferLengthIssueShowcase
{
    /**
     * Prints the diagnostic report for {@code arr} under an optional banner,
     * without dumping the array's element values.
     *
     * @param arr   the array (or view) to inspect
     * @param title banner text; no banner is printed when {@code null}
     */
    public static void printArrayAndBuffer(INDArray arr, String title) {
        printArrayAndBuffer(arr, title, false);
    }

    /**
     * Prints a diagnostic report for {@code arr}: an optional banner line, an
     * optional dump of the elements, ND4J's shape-info string, the array length
     * next to the length of its backing data buffer, and the buffer offsets of
     * the first and last elements.
     *
     * @param arr        the array (or view) to inspect
     * @param title      banner text framed by '=' characters; skipped when {@code null}
     * @param printArray whether to print {@code arr.toString()} before the shape info
     */
    public static void printArrayAndBuffer(INDArray arr, String title, boolean printArray){
        if (title != null) {
            // Pad " title " with '=' on both sides so the line is 52 characters
            // wide whenever both pad widths are non-negative; a negative width
            // (very long title) just yields no '=' on that side.
            int before = 25 - title.length() / 2;
            int after = 50 - before - title.length();
            StringBuilder banner = new StringBuilder();
            appendRepeated(banner, '=', before);
            banner.append(' ').append(title).append(' ');
            appendRepeated(banner, '=', after);
            banner.append('\n');
            System.out.print(banner);
        }

        if (printArray) {
            System.out.println(arr.toString());
        }
        System.out.println(arr.shapeInfoToString());

        long arrayLength = arr.length();
        long bufferLength = arr.data().length();
        System.out.println(String.format("length of array: %d, length of databuffer: %d", arrayLength, bufferLength));

        long firstOffset = arr.offset();
        long lastOffset = offsetOfLastElement(arr);
        System.out.println(String.format("offset of first element: %d\tindex of last element: %d", firstOffset, lastOffset));
    }

    /** Appends {@code count} copies of {@code c} to {@code sb}; appends nothing when {@code count <= 0}. */
    private static void appendRepeated(StringBuilder sb, char c, int count) {
        for (int k = 0; k < count; k++) {
            sb.append(c);
        }
    }

    /**
     * Computes the data-buffer position of the element at the maximal index in
     * every dimension, i.e. {@code offset() + sum over i of stride(i) * (shape[i] - 1)}.
     *
     * NOTE(review): treating this as the "last" element assumes non-negative
     * strides and no zero-length dimension; confirm for the arrays passed in.
     *
     * @param arr the array (or view) to inspect
     * @return the buffer offset of the element at index (shape[0]-1, ..., shape[rank-1]-1)
     */
    public static long offsetOfLastElement(INDArray arr){
        long position = arr.offset();
        int rank = arr.rank();
        for (int dim = 0; dim < rank; dim++) {
            position += arr.stride(dim) * (arr.shape()[dim] - 1);
        }
        return position;
    }

    /**
     * Builds a 10x10 matrix holding 0..99 and reports on it, on single rows and
     * columns taken from it, and on a permuted 3-D sub-view. In the recorded run,
     * the final report (the permuted view after {@code convertToDoubles()})
     * failed with an IndexOutOfBoundsException while formatting the array.
     */
    public static void main(String[] argv){
        INDArray matrix = Nd4j.linspace(0, 99, 100).reshape(10, 10);
        printArrayAndBuffer(matrix, "Initial array");

        // Row/column views at the start and in the middle of the matrix.
        printArrayAndBuffer(matrix.getRow(0), "first row");
        printArrayAndBuffer(matrix.getColumn(0), "first column");
        printArrayAndBuffer(matrix.getRow(5), "mid row");
        printArrayAndBuffer(matrix.getColumn(5), "mid column");

        // The same 100 values viewed as 5x4x5, then a 2x2x5 window permuted to 5x2x2.
        INDArray cube = matrix.reshape(5, 4, 5);
        INDArray window = cube.get(
                NDArrayIndex.interval(2, 4),
                NDArrayIndex.interval(1, 3),
                NDArrayIndex.all());
        INDArray permuted = window.permute(2, 0, 1);
        printArrayAndBuffer(permuted, "3D permuted", true);
        printArrayAndBuffer(permuted.convertToDoubles(), "3D permuted and converted", true);
    }
}
=================== Initial array ==================
Rank: 2,Offset: 0
Order: c Shape: [10,10], stride: [10,1]
length of array: 100, length of databuffer: 100
offset of first element: 0 index of last element: 99
===================== first row ====================
Rank: 2,Offset: 0
Order: c Shape: [1,10], stride: [1,1]
length of array: 10, length of databuffer: 100
offset of first element: 0 index of last element: 9
=================== first column ===================
Rank: 2,Offset: 0
Order: c Shape: [10,1], stride: [10,1]
length of array: 10, length of databuffer: 100
offset of first element: 0 index of last element: 90
====================== mid row =====================
Rank: 2,Offset: 0
Order: c Shape: [1,10], stride: [1,1]
length of array: 10, length of databuffer: 10
offset of first element: 50 index of last element: 59
==================== mid column ====================
Rank: 2,Offset: 0
Order: c Shape: [10,1], stride: [10,1]
length of array: 10, length of databuffer: 10
offset of first element: 5 index of last element: 95
==================== 3D permuted ===================
[[[ 45.0000, 50.0000],
[ 65.0000, 70.0000]],
[[ 46.0000, 51.0000],
[ 66.0000, 71.0000]],
[[ 47.0000, 52.0000],
[ 67.0000, 72.0000]],
[[ 48.0000, 53.0000],
[ 68.0000, 73.0000]],
[[ 49.0000, 54.0000],
[ 69.0000, 74.0000]]]
Rank: 3,Offset: 0
Order: c Shape: [5,2,2], stride: [1,20,5]
length of array: 20, length of databuffer: 20
offset of first element: 45 index of last element: 74
============= 3D permuted and converted ============
Exception in thread "main" java.lang.IndexOutOfBoundsException: 20
at org.bytedeco.javacpp.indexer.Indexer.checkIndex(Indexer.java:90)
at org.bytedeco.javacpp.indexer.DoubleRawIndexer.get(DoubleRawIndexer.java:59)
at org.nd4j.linalg.api.buffer.BaseDataBuffer.getDouble(BaseDataBuffer.java:1042)
at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:4420)
at org.nd4j.linalg.string.NDArrayStrings.vectorToString(NDArrayStrings.java:223)
at org.nd4j.linalg.string.NDArrayStrings.format(NDArrayStrings.java:173)
at org.nd4j.linalg.string.NDArrayStrings.format(NDArrayStrings.java:196)
at org.nd4j.linalg.string.NDArrayStrings.format(NDArrayStrings.java:196)
at org.nd4j.linalg.string.NDArrayStrings.format(NDArrayStrings.java:142)
at org.nd4j.linalg.string.NDArrayStrings.format(NDArrayStrings.java:121)
at org.nd4j.linalg.api.ndarray.BaseNDArray.toString(BaseNDArray.java:5722)
at DataBufferLengthIssueShowcase.printArrayAndBuffer(DataBufferLengthIssueShowcase.java:20)
at DataBufferLengthIssueShowcase.main(DataBufferLengthIssueShowcase.java:50)
Disconnected from the target VM, address: '127.0.0.1:34963', transport: 'socket'
Process finished with exit code 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment