Skip to content

Instantly share code, notes, and snippets.

@raphlinus
Created April 23, 2018 19:31
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raphlinus/591b7708959ebdc97759b560e215dd26 to your computer and use it in GitHub Desktop.
Save raphlinus/591b7708959ebdc97759b560e215dd26 to your computer and use it in GitHub Desktop.
Sketch of using a prefix sum (parallel scan) to do backslash unescaping in CUDA.
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "stdio.h"
#include "stdint.h"
#include <cstring>
#include <algorithm>
#include <chrono>
#include <ctime>
#include <vector>
#include "thrust/device_vector.h"
#include "thrust/copy.h"
#include "thrust/transform_scan.h"
struct map_fsm {
__device__ uint8_t operator()(uint8_t in) {
if (in == '"') {
return 0xd1;
} else if (in == '\\') {
return 0xdb;
} else {
return 0xd4;
}
}
};
struct compose_fsm {
__device__ uint8_t operator()(uint8_t a, uint8_t b) {
return ((b >> ((a << 1) & 6)) & 3) |
(((b >> ((a >> 1) & 6)) & 3) << 2) |
(((b >> ((a >> 3) & 6)) & 3) << 4) |
0xc0;
}
};
struct keep {
__device__ bool operator()(uint8_t a) {
return (a & 3) == 1;
}
};
// Sequential reference implementation of the string-extraction state machine.
//
// Reads `n` bytes from `inp` and writes every byte that leaves the machine in
// the "inside a string" state to `out`. Returns the number of bytes written,
// which is at most `n`, so `out` must have room for `n` bytes.
//   - The opening '"' of a string is written; the closing '"' is not.
//   - A backslash inside a string is dropped, and the byte after it is
//     written verbatim (no translation of sequences such as \n).
//   - A backslash outside a string is ignored.
//     NOTE(review): map_fsm treats that case as an absorbing error state
//     instead; confirm which behavior is intended.
size_t scalar(const uint8_t* inp, size_t n, uint8_t* out) {
    // 0 = outside a string, 1 = inside a string, 2 = right after a backslash.
    int st = 0;
    size_t written = 0;
    for (const uint8_t* p = inp; p != inp + n; ++p) {
        const uint8_t c = *p;
        switch (st) {
        case 0:
            if (c == '"') {
                st = 1;
            }
            break;
        case 1:
            if (c == '\\') {
                st = 2;
            } else if (c == '"') {
                st = 0;
            }
            break;
        default:
            // The escaped byte always returns us to the string body.
            st = 1;
            break;
        }
        if (st == 1) {
            out[written++] = c;
        }
    }
    return written;
}
// Benchmark driver. Replicates a short JSON-like string literal n_copies times
// and extracts string contents (see scalar() for the exact rule). It then
// reports the elapsed wall-clock time and throughput and prints the first 61
// output bytes.
// Change `#if 1` to `#if 0` to time the CPU reference implementation instead
// of the thrust scan + copy_if pipeline.
int main()
{
    const int n_copies = 10000000;
    const char* input = "\"string with quote (\\\") and backslash(\\\\)\"";
    const size_t input_len = std::strlen(input);
    const size_t total_len = input_len * n_copies;
    printf("input: %s, total size: %zu\n", input, total_len);
    thrust::device_vector<uint8_t> inp(total_len);
    std::vector<uint8_t> rep_input(total_len);
    const uint8_t* input_bytes = reinterpret_cast<const uint8_t*>(input);
    for (int i = 0; i < n_copies; i++) {
        std::copy(input_bytes, input_bytes + input_len, rep_input.begin() + i * input_len);
    }
    // std::clock() reports processor time used by the process, not elapsed
    // time (e.g. it need not advance while the host is blocked waiting for
    // the GPU), so throughput is measured with a monotonic wall clock.
#if 1
    thrust::copy(rep_input.begin(), rep_input.end(), inp.begin());
    const auto start = std::chrono::steady_clock::now();
    thrust::device_vector<uint8_t> scan(total_len);
    thrust::transform_inclusive_scan(inp.begin(), inp.end(), scan.begin(), map_fsm{}, compose_fsm{});
    thrust::device_vector<uint8_t> out(total_len);
    // copy_if returns the end of the compacted range. It presumably has to
    // wait for the device-side selection count to do that, so the clock read
    // below covers the GPU work.
    auto end = thrust::copy_if(inp.begin(), inp.end(), scan.begin(), out.begin(), keep{});
#else
    const auto start = std::chrono::steady_clock::now();
    std::vector<uint8_t> out(total_len);
    auto end = out.begin() + scalar(rep_input.data(), total_len, out.data());
#endif
    const double elapsed =
        std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count();
    const double throughput = total_len / elapsed * 1e-6;
    printf("elapsed: %g (%gMB/s)\n", elapsed, throughput);
    printf("result: ");
    // Print at most the first 61 output bytes.
    int count = 0;
    for (auto it = out.begin(); it != end; ++it) {
        const uint8_t b = *it;
        printf("%c", b);
        if (count++ == 60) break;
    }
    printf("\n");
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment