@MattFanto
Last active July 3, 2018 16:20
Group by sliding window in tensorflow
import numpy as np

# Group rows of X into overlapping sliding windows: each output row of Xt is
# one window of `window_size` consecutive samples, flattened into a single row.
length = int(X.shape[0] / step) - window_size
Xt = np.empty((length, window_size*len(GCLOUD_SENSOR_COLS)))
for i in range(length):
    Xt[i] = X[i*step:i*step+window_size].ravel()
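A minimal, runnable sketch of the same numpy grouping, with made-up values for step, window_size, the sensor columns and the input matrix X (GCLOUD_SENSOR_COLS comes from the original snippet; the concrete values below are illustrative only):

import numpy as np

# Assumed example values; the real ones are defined elsewhere in the project.
GCLOUD_SENSOR_COLS = ["acc_x", "acc_y", "acc_z"]
window_size = 4
step = 2

# Fake sensor matrix: 20 samples x 3 sensor channels.
X = np.arange(20 * len(GCLOUD_SENSOR_COLS), dtype=np.float64).reshape(20, -1)

length = int(X.shape[0] / step) - window_size
Xt = np.empty((length, window_size * len(GCLOUD_SENSOR_COLS)))
for i in range(length):
    # Each output row is one flattened window of consecutive samples.
    Xt[i] = X[i*step:i*step + window_size].ravel()

print(Xt.shape)  # (6, 12): 6 windows of 4 samples x 3 channels each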
/*
* Welcome to the Java Apache Beam version!!
*/
package com.dermatrack.dataflow.preprocessing;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.*;
import org.apache.beam.sdk.extensions.gcp.storage.GcsCreateOptions;
import org.apache.beam.sdk.io.*;
import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubOptions;
import org.apache.beam.sdk.options.*;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.values.PCollection;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
/**
* Read all merged csv files and apply the group-by-sliding-window transformation
*/
public class PreProcessing {
/**
* Pipeline options for the sliding-window preprocessing job.
*/
interface Options extends PubsubOptions {
@Description("Output file bucket location")
@Default.String("gs://dermatrack-mlengine/cleaned_data/regr_filt_augm/v1/w50/")
@Validation.Required
String getOutputDir();
void setOutputDir(String value);
@Description("Size of the sliding window")
@Default.Integer(50)
int getWindowSize();
void setWindowSize(int value);
@Description("Num step for each sliding window")
@Default.Integer(5)
int getStepSize();
void setStepSize(int value);
}
/**
* Utility class to read a file line by line from a ReadableByteChannel
*/
private static class LineReader {
private ReadableByteChannel channel = null;
private long nextLineStart = 0;
private long currentLineStart = 0;
private final ByteBuffer buf;
private static final int BUF_SIZE = 1024;
private String currentValue = null;
public LineReader(final ReadableByteChannel channel)
throws IOException {
buf = ByteBuffer.allocate(BUF_SIZE);
buf.flip();
boolean removeLine = false;
// If we are not at the beginning of a line, we should ignore the current line.
if (channel instanceof SeekableByteChannel) {
SeekableByteChannel seekChannel = (SeekableByteChannel) channel;
if (seekChannel.position() > 0) {
// Start from one character back and read till we find a new line.
seekChannel.position(seekChannel.position() - 1);
removeLine = true;
}
nextLineStart = seekChannel.position();
}
this.channel = channel;
if (removeLine) {
nextLineStart += readNextLine(new ByteArrayOutputStream());
}
}
private int readNextLine(final ByteArrayOutputStream out) throws IOException {
int byteCount = 0;
while (true) {
if (!buf.hasRemaining()) {
buf.clear();
int read = channel.read(buf);
if (read < 0) {
break;
}
buf.flip();
}
byte b = buf.get();
byteCount++;
if (b == '\n') {
break;
}
out.write(b);
}
return byteCount;
}
public boolean readNextLine() throws IOException {
currentLineStart = nextLineStart;
ByteArrayOutputStream buf = new ByteArrayOutputStream();
int offsetAdjustment = readNextLine(buf);
if (offsetAdjustment == 0) {
// EOF
return false;
}
nextLineStart += offsetAdjustment;
// When running on Windows, each line obtained from 'readNextLine()' will end with a '\r'
// since we use '\n' as the line boundary of the reader. So we trim it off here.
currentValue = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), buf.toByteArray()).trim();
return true;
}
public String getCurrent() {
return currentValue;
}
public long getCurrentLineStart() {
return currentLineStart;
}
}
/**
* Groups a list of records into a single window row and reduces their labels into one value:
* reduced_label = n_itch / total_records
* @param records the records belonging to one sliding window
* @return the window as a single CSV line, terminated by '\n'
*/
private static String recordsToWindow(String[] records) {
int itch = 0;
StringBuilder window = new StringBuilder();
for (String rec : records) {
String[] values = rec.split(",");
String label = values[values.length-1];
// Count how many records in this window are labelled "itch" (use equals, not ==, for string comparison)
if ("itch".equals(label)) {
itch += 1;
}
// Append the record without its label; the trailing comma is kept as separator
window.append(rec.substring(0, rec.length() - label.length()));
}
// Reduce the labels to a single fraction (floating-point division, otherwise the result is always 0 or 1)
window.append((double) itch / records.length);
return window.toString() + "\n";
}
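// Illustrative example of recordsToWindow (hypothetical values, windowLen = 2):
//   records = {"0.1,0.2,itch", "0.3,0.4,rest"}
//   -> "0.1,0.2,0.3,0.4,0.5\n"   (1 "itch" out of 2 records -> reduced label 0.5)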
public static void main(String[] args) throws Exception {
Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
// Enforce that this pipeline is always run in streaming mode.
options.setStreaming(true);
final int windowLen = options.getWindowSize();
final int stepSize = options.getStepSize();
final String outputDir = options.getOutputDir();
Pipeline pipeline = Pipeline.create(options);
MatchResult ms = FileSystems.match("gs://dermatrack-mlengine/cleaned_data/merged/*.csv");
PCollection<MatchResult.Metadata> filesMetadata = pipeline.apply("RetrieveFileList", Create.of(ms.metadata()));
filesMetadata
.apply("GroupBySlidingWindow",
ParDo.of(new DoFn<MatchResult.Metadata, String>() {
@ProcessElement
public void processElement(ProcessContext c) throws IOException {
System.out.println(c.element());
ReadableByteChannel channel = FileSystems.open(c.element().resourceId());
LineReader lineReader = new LineReader(channel);
List<String> windows = new ArrayList<String>();
int outputSize = 0;
// The first window deserves special treatment: fill it completely from the file
String[] records = new String[windowLen];
for (int i = 0; i < windowLen; i++) {
if (!lineReader.readNextLine()) {
return;
} else {
records[i] = lineReader.getCurrent();
outputSize += records[i].getBytes().length;
}
}
windows.add(recordsToWindow(records));
// Following windows can reuse the previously read records and only need stepSize new lines
while (true) {
String[] new_records = new String[windowLen];
for (int i = 0; i < windowLen - stepSize; i++) {
new_records[i] = records[i + stepSize];
}
boolean endOfFile = false;
for (int i = windowLen - stepSize; i < windowLen; i++) {
if (!lineReader.readNextLine()) {
endOfFile = true;
break;
} else {
new_records[i] = lineReader.getCurrent();
outputSize += new_records[i].getBytes().length;
}
}
if (!endOfFile) {
records = new_records;
windows.add(recordsToWindow(records));
}
else {
// Not enough records left to fill this window, so stop
break;
}
}
// Write the windows to the output blob
String outFileName = outputDir + c.element().resourceId().getFilename();
WritableByteChannel writeChannel = FileSystems.create(
FileSystems.matchNewResource(outFileName, false), GcsCreateOptions.builder().setMimeType(MimeTypes.BINARY).build());
for (String rec : windows) {
byte [] record = rec.getBytes();
ByteBuffer byteBuffer = ByteBuffer.allocate(record.length);
byteBuffer.put(record);
byteBuffer.position(0);
writeChannel.write(byteBuffer);
}
writeChannel.close();
}
}
)
);
pipeline.run().waitUntilFinish();
}
}
import numpy as np
import tensorflow as tf

def filter_overlapping_values(x, window_size):
    # Keep the second half of window(i) and the first half of window(i+1)
    s1 = tf.slice(x[0], [window_size//2, 0], [-1, -1])
    s2 = tf.slice(x[1], [0, 0], [window_size//2, -1])
    return tf.concat((s1, s2), axis=0)
length = 12
components = np.array([[i] for i in range(length)], dtype=np.int64)
# components = np.arange(6 * 4, dtype=np.int64).reshape((-1, 4))
dataset = tf.data.Dataset.from_tensor_slices(components)
window_size = 4
# window consecutive elements with batch
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(window_size))
# [[0][1][2][3]]
# [[4][5][6][7]]
# [[8][9][10][11]]
# Duplicate every row and skip the first one; this allows building overlapping windows
dataset1 = dataset.apply(tf.contrib.data.group_by_window(lambda x: 0, lambda k, d: d.repeat(2), window_size=1)).skip(1)
# [[0][1][2][3]]
# [[4][5][6][7]]
# [[4][5][6][7]]
# [[8][9][10][11]]
# [[8][9][10][11]]
# Use batch to merge consecutive duplicated rows into a single element holding both window(i) and window(i+1)
dataset1 = dataset1.apply(tf.contrib.data.batch_and_drop_remainder(2))
# [ [[0][1][2][3]] [[4][5][6][7]] ]
# [ [[4][5][6][7]] [[8][9][10][11]] ]
# Slice out of each pair only the values needed for the overlapping window
dataset1 = dataset1.map(lambda x: filter_overlapping_values(x, window_size))
# [[2][3][4][5]]
# [[6][7][8][9]]
# Zip the original windows with the overlapping windows so they can be interleaved in order
dataset = tf.data.Dataset.zip((dataset, dataset1))
# x0: [[0][1][2][3]] x1: [[2][3][4][5]]
# x0: [[4][5][6][7]] x1: [[6][7][8][9]]
# Interleave the original windows and the overlapping windows into a single flat dataset
dataset = dataset.flat_map(lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(tf.data.Dataset.from_tensors(x1)))
# [[0][1][2][3]]
# [[2][3][4][5]]
# [[4][5][6][7]]
# [[6][7][8][9]]
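A minimal sketch of how to iterate the resulting dataset and check the windows listed above, assuming TensorFlow 1.x in graph mode (the same tf.contrib era as the snippet):

# Iterate the final dataset and print each (window_size, 1) window.
iterator = dataset.make_one_shot_iterator()
next_window = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_window))
    except tf.errors.OutOfRangeError:
        # Dataset exhausted
        pass
# Expected output: [[0][1][2][3]], [[2][3][4][5]], [[4][5][6][7]], [[6][7][8][9]]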