Skip to content

Instantly share code, notes, and snippets.

@cogmission
Created January 29, 2016 21:01
Show Gist options
  • Save cogmission/e7b06f5ab545a80b19be to your computer and use it in GitHub Desktop.
Save cogmission/e7b06f5ab545a80b19be to your computer and use it in GitHub Desktop.
SDR RoaringBitmap wrapper API (work in progress)
package org.numenta.nupic.research.sdrs;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.IntStream;
import org.roaringbitmap.RoaringBitmap;
/**
 * A fixed-width binary vector (Sparse Distributed Representation) that stores only the
 * indices of its ON (1) bits, backed by a {@link RoaringBitmap}.
 *
 * <p>Two {@code SDR}s are equal only when they have the same width ({@link #size()}) and the
 * same ON bits.
 */
public class SDR implements Iterable<Integer> {
    /** Indices of the ON (1) bits; assigned once in the constructor and never null. */
    private final RoaringBitmap bits;
    /** Total width of this SDR, counting OFF (0) bits as well. */
    private final int size;

    private SDR(int width, int... bits) {
        this.size = width;
        this.bits = RoaringBitmap.bitmapOf(bits);
    }

    /**
     * Returns an {@code SDR} constructed from <b>sparse</b> indices
     * ([1,2,7,11] <b><em>NOT</em></b> [0,1,1,0,0,0,0,1,0,0,0,1]).
     * For efficiency the indices are <em>not</em> validated against {@code width},
     * so please be careful: misuse is not reported.
     *
     * <p><b>Warning:</b> the first parameter <b>must</b> be the total width.
     *
     * @param width     the total width of the SDR (counting zero bits)
     * @param positions the indices of the ON bits
     * @return a new {@code SDR}
     */
    public static SDR fromSparse(int width, int... positions) {
        return new SDR(width, positions);
    }

    /**
     * Returns an {@code SDR} constructed from <b>sparse</b> indices supplied in any order
     * ([7,2,11,1] (unsorted) <b><em>NOT</em></b> [0,1,1,0,0,0,0,1,0,0,0,1]). The indices are
     * sorted before the bitmap is built. For efficiency they are <em>not</em> validated
     * against {@code width}, so please be careful: misuse is not reported.
     *
     * <p><b>Warning:</b> the first parameter <b>must</b> be the total width.
     *
     * @param width     the total width of the SDR (counting zero bits)
     * @param positions the indices of the ON bits, in any order
     * @return a new {@code SDR}
     */
    public static SDR fromUnsortedSparse(int width, int... positions) {
        return new SDR(width, Arrays.stream(positions).sorted().toArray());
    }

    /**
     * Returns an {@code SDR} constructed from a <b>dense</b> bit array
     * ([0,1,1,0,0,0,0,1,0,0,0,1] <b><em>NOT</em></b> [1,2,7,11]).
     * The width is the array length. Only entries equal to {@code 1} become ON bits;
     * every other value is silently treated as OFF, so please be careful.
     *
     * @param bits the dense bit array
     * @return a new {@code SDR}
     */
    public static SDR fromDense(int... bits) {
        return new SDR(bits.length, IntStream.range(0, bits.length).filter(i -> bits[i] == 1).toArray());
    }

    /**
     * Returns the total width of this {@code SDR} (counting zero bits as well).
     *
     * @return the width
     */
    public int size() {
        return size;
    }

    /**
     * Returns the number of ON (1) bits in this {@code SDR}.
     *
     * @return the number of ON bits
     */
    public int cardinality() {
        return bits.getCardinality();
    }

    /**
     * Returns the bit value at the absolute (dense) position specified.
     *
     * <p>(i.e. for: [1,2,4,7,11] --> getValue(8) == 0, whereas getValue(...)
     * with 1, 2, 4, 7, or 11 returns 1).
     *
     * @param denseIndex the index for which a value (either 0 or 1) is returned
     * @return 1 if the bit at {@code denseIndex} is ON, otherwise 0
     */
    public int getValue(int denseIndex) {
        return bits.contains(denseIndex) ? 1 : 0;
    }

    /**
     * Returns the sparse index of the {@code position}-th ON bit (0-based, ascending).
     *
     * <p>(i.e. for: [1,2,4,7,11] --> getPosition() with 0, 1, 2, 3, or 4 returns
     * 1, 2, 4, 7, 11 respectively, and getPosition(8) throws
     * {@link IllegalArgumentException}).
     *
     * @param position the rank of the ON bit to return
     * @return the sparse index of that ON bit
     * @throws IllegalArgumentException if {@code position} is negative or not less than
     *         {@link #cardinality()}
     */
    public int getPosition(int position) {
        // Reject negative ranks explicitly so every out-of-range request fails with the same
        // exception type as the "position >= cardinality" case reported by select().
        if (position < 0) {
            throw new IllegalArgumentException("position must be non-negative: " + position);
        }
        return bits.select(position);
    }

    /**
     * Returns an {@link Iterator} over the indices of the ON bits in this {@link SDR}.
     *
     * <p>NOTE(review): this is the backing bitmap's own iterator; if it supports
     * {@code remove()} in the RoaringBitmap version in use, callers could mutate this SDR.
     * Confirm before relying on it being read-only.
     */
    @Override
    public Iterator<Integer> iterator() {
        return bits.iterator();
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: combines the width and the ON bits.
     *
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
        return 31 * bits.hashCode() + size;
    }

    /**
     * Two {@code SDR}s are equal when they have the same width and the same ON bits.
     *
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        SDR other = (SDR) obj;
        return size == other.size && bits.equals(other.bits);
    }

    /**
     * Returns a debugging representation showing the width and the ON-bit indices.
     */
    @Override
    public String toString() {
        return "SDR[size=" + size + ", bits=" + bits + "]";
    }
}
@cogmission
Copy link
Author

Here's the test:

package org.numenta.nupic.research.sdrs;

import static org.junit.Assert.*;

import org.junit.Test;


public class SDRTest {

    @Test
    public void testConstructionAndEquals() {
        SDR sdr1 = SDR.fromSparse(12, 1, 2, 7, 11);
        SDR sdr2 = SDR.fromSparse(12, new int[] { 1, 2, 7, 11 });
        assertEquals(sdr1, sdr2);
        // Equal objects must have equal hash codes.
        assertEquals(sdr1.hashCode(), sdr2.hashCode());

        SDR sdr3 = SDR.fromSparse(12, 1, 3, 7, 11); // 3 is different from 2
        SDR sdr4 = SDR.fromSparse(12, new int[] { 1, 2, 7, 11 });
        assertFalse(sdr3.equals(sdr4));

        SDR sdr5 = SDR.fromDense(0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1);
        SDR sdr6 = SDR.fromDense(new int[] { 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1 });
        assertEquals(sdr5, sdr6);
        assertEquals(sdr5.hashCode(), sdr6.hashCode());

        SDR sdr7 = SDR.fromDense(0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1); // bit 3 set instead of bit 2
        SDR sdr8 = SDR.fromDense(new int[] { 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1 });
        assertFalse(sdr7.equals(sdr8));

        SDR sdr9 = SDR.fromUnsortedSparse(12, 7, 2, 11, 1);
        SDR sdr10 = SDR.fromSparse(12, new int[] { 1, 2, 7, 11 });
        assertEquals(sdr9, sdr10);

        // Sparse and dense construction of the same bits and width must agree.
        assertEquals(sdr1, sdr5);

        // Test size
        assertEquals(12, sdr1.size());
        assertEquals(12, sdr8.size());
    }

    @Test
    public void testGetIndexes() {
        int[] indexes = { 1, 2, 7, 11 };
        int[] denseArray = { 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1 };

        SDR sdr1 = SDR.fromSparse(12, 1, 2, 7, 11);
        assertIteratesExactly(indexes, sdr1);

        SDR sdr2 = SDR.fromDense(0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1);
        assertIteratesExactly(indexes, sdr2);

        assertEquals(0, sdr2.getValue(8));
        assertEquals(0, sdr1.getValue(8));

        try {
            sdr1.getPosition(8); // the returned value is irrelevant; the call must throw
            fail("expected IllegalArgumentException for a position beyond the cardinality");
        } catch (IllegalArgumentException e) {
            assertEquals("select 8 when the cardinality is 4", e.getMessage());
        }

        // Random access by cardinality
        assertEquals(indexes.length, sdr1.cardinality());
        for (int i = 0; i < sdr1.cardinality(); i++) {
            assertEquals(indexes[i], sdr1.getPosition(i));
        }

        // Random access by size
        assertEquals(denseArray.length, sdr1.size());
        for (int i = 0; i < sdr1.size(); i++) {
            assertEquals(denseArray[i], sdr1.getValue(i));
        }
    }

    /**
     * Asserts that iterating {@code sdr} yields exactly the values in {@code expected},
     * in the same order, with no missing or extra elements.
     */
    private static void assertIteratesExactly(int[] expected, SDR sdr) {
        int idx = 0;
        for (Integer value : sdr) {
            assertTrue("iterator yielded more than " + expected.length + " elements",
                    idx < expected.length);
            assertEquals(expected[idx++], value.intValue());
        }
        assertEquals("iterator yielded too few elements", expected.length, idx);
    }

}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment