Skip to content

Instantly share code, notes, and snippets.

@t-mat
Forked from easyaspi314/change_case_simd.c
Last active December 19, 2023 02:35
Show Gist options
  • Save t-mat/d6664e97c8b6407088356b3867d43537 to your computer and use it in GitHub Desktop.
Save t-mat/d6664e97c8b6407088356b3867d43537 to your computer and use it in GitHub Desktop.
SIMD functions to apply toupper/tolower to each character in a string
// Created by easyaspi314. Released into the public domain.
// test: $ gcc -msse2 -DTEST change_case_simd.c && ./a.out
//       > cl /EHsc /DTEST change_case_simd.c && .\change_case_simd.exe
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef DEMONSTRATE_BASIC_ALGORITHM
#include <stdint.h>
#include <immintrin.h>
//
// To reduce the number of instructions, we can use the following classic
// comparison technique:
//
// int8_t src = ...; // input UTF-8 byte
// int8_t e = (int8_t) (src - 'A' + 0x80);
// int8_t u = (int8_t) ('Z' - 'A' + 0x80 + 1);
// bool isUpper1 = e < u;
//
// note: u = -102 = (0x5a - 0x41 + 0x80 + 1)
//
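// This trick works because adding 0x80 moves 'A'..'Z' onto the lowest
// int8_t values (-128..-103), so the single signed comparison is
// equivalent to the two-sided range check
//
// isUpper1 = (-128 <= e && e < -102)
//
// whose first condition is always true for an int8_t.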
//
// We can confirm it in the following table:
//
// | src | src-'A'+0x80 | -102 > src-'A'+0x80 |
// | ------- | ------------ | ------------------- |
// | 0xc0 | 0xff = -1 | false |
// | ... | ... | false |
// | '['=0x5b | 0x9a = -102 | false |
// | 'Z'=0x5a | 0x99 = -103 | true |
// | 'Y'=0x59 | 0x98 = -104 | true |
// | ... | ... | true |
// | 'B'=0x42 | 0x81 = -127 | true |
// | 'A'=0x41 | 0x80 = -128 | true |
// | '@'=0x40 | 0x7f = +127 = -129 & 255 | false |
// | ... | ... | false |
// | 0x01 | 0x40 = +64 | false |
// | 0x00 | 0x3f = +63 | false |
// | 0xff | 0x3e = +62 | false |
// | ... | ... | false |
// | 0xc1 | 0x00 = +0 | false |
//
// Scalar model of the SIMD lowercase step for a single byte.
// Returns src + ('a' - 'A') when src is in 'A'..'Z', otherwise src unchanged.
static int8_t CharToLower(int8_t src) {
    // Bias that moves 'A' to -128, the bottom of the int8_t range.
    const int8_t bias = (int8_t)(0x80 - 'A');
    // First biased value that lies just past 'Z' (-102).
    const int8_t limit = (int8_t)('Z' - 'A' + 1 + 0x80);
    // Distance from an uppercase letter to its lowercase form.
    const int8_t caseDelta = (int8_t)('a' - 'A');

    const int8_t biased = (int8_t)(src + bias);
    // All-ones when the byte is an uppercase letter, zero otherwise.
    const int8_t mask = (int8_t)((limit > biased) ? 0xff : 0);
    return (int8_t)(src + (mask & caseDelta));
}
// Scalar model of the SIMD uppercase step for a single byte.
// Returns src + ('A' - 'a') when src is in 'a'..'z', otherwise src unchanged.
static int8_t CharToUpper(int8_t src) {
    // Bias that moves 'a' to -128, the bottom of the int8_t range.
    const int8_t bias = (int8_t)(0x80 - 'a');
    // First biased value that lies just past 'z' (-102).
    const int8_t limit = (int8_t)('z' - 'a' + 1 + 0x80);
    // Distance (negative) from a lowercase letter to its uppercase form.
    const int8_t caseDelta = (int8_t)('A' - 'a');

    const int8_t biased = (int8_t)(src + bias);
    // All-ones when the byte is a lowercase letter, zero otherwise.
    const int8_t mask = (int8_t)((limit > biased) ? 0xff : 0);
    return (int8_t)(src + (mask & caseDelta));
}
// Pseudo SIMD
// 128-bit value that can be viewed as 16 unsigned bytes, 16 signed bytes,
// or a single __m128i. The byte views let plain C loops process the same
// 16 lanes that one SSE2 instruction would.
typedef union U128 {
uint8_t u8[16]; // lanes as unsigned bytes
int8_t i8[16]; // lanes as signed bytes (used by the signed-compare trick)
__m128i m128; // the whole value as an SSE integer vector
} U128;
// Plain-C emulation of the SSE2 lowercase kernel: applies the biased signed
// compare to each of the 16 signed-byte lanes of src and returns the result.
// Lanes holding 'A'..'Z' get 'a' - 'A' added; every other lane is unchanged.
static U128 CharToLowerU128(U128 src) {
    const int8_t bias = (int8_t)(0x80 - 'A');
    const int8_t limit = (int8_t)('Z' - 'A' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('a' - 'A');

    // src is a by-value copy, so it doubles as the output buffer.
    int lane = 0;
    while (lane < 16) {
        const int8_t in = src.i8[lane];
        const int8_t biased = (int8_t)(in + bias);
        // All-ones for an uppercase letter, zero otherwise.
        const int8_t mask = (int8_t)((limit > biased) ? 0xff : 0);
        src.i8[lane] = (int8_t)(in + (mask & caseDelta));
        ++lane;
    }
    return src;
}
// Plain-C emulation of the SSE2 uppercase kernel: applies the biased signed
// compare to each of the 16 signed-byte lanes of src and returns the result.
// Lanes holding 'a'..'z' get 'A' - 'a' added; every other lane is unchanged.
static U128 CharToUpperU128(U128 src) {
    const int8_t bias = (int8_t)(0x80 - 'a');
    const int8_t limit = (int8_t)('z' - 'a' + 1 + 0x80);
    const int8_t caseDelta = (int8_t)('A' - 'a');

    // src is a by-value copy, so it doubles as the output buffer.
    int lane = 0;
    while (lane < 16) {
        const int8_t in = src.i8[lane];
        const int8_t biased = (int8_t)(in + bias);
        // All-ones for a lowercase letter, zero otherwise.
        const int8_t mask = (int8_t)((limit > biased) ? 0xff : 0);
        src.i8[lane] = (int8_t)(in + (mask & caseDelta));
        ++lane;
    }
    return src;
}
#endif
#if defined(__SSE2__) || defined(_M_X64)
#include <immintrin.h>
// SSE2 implementation: converts every uppercase ASCII letter ('A'..'Z') in
// the NUL-terminated string `str` to lowercase, in place. All other bytes,
// including bytes >= 0x80, are left untouched, and the result does not
// depend on the current locale.
//
// `str` must point to a writable, NUL-terminated buffer (it is passed to
// strlen()).
void StringToLower(char *str)
{
    // asciiAofs80: bias that moves 'A'..'Z' to the bottom of the signed
    //              byte range (-128..-103).
    // AtoZp1ofs80: first biased value just past 'Z' (-102).
    // diff:        distance from an uppercase letter to its lowercase form.
    const __m128i asciiAofs80 = _mm_set1_epi8(-'A' + 0x80);
    const __m128i AtoZp1ofs80 = _mm_set1_epi8((signed char) ('Z' - 'A' + 1 + 0x80));
    const __m128i diff = _mm_set1_epi8('a' - 'A');

    size_t len = strlen(str);

    // Bulk: 16 bytes per iteration, using unaligned load/store because
    // `str` has no alignment guarantee.
    for (; len >= 16; len -= 16, str += 16) {
        const __m128i src = _mm_loadu_si128((const __m128i *) str);
        const __m128i a1 = _mm_add_epi8(src, asciiAofs80);    // bias every byte
        const __m128i a2 = _mm_cmpgt_epi8(AtoZp1ofs80, a1);   // 0xff where byte is 'A'..'Z', else 0x00
        const __m128i a3 = _mm_and_si128(a2, diff);           // 'a'-'A' for letters, else 0
        const __m128i a4 = _mm_add_epi8(src, a3);             // apply the case delta
        _mm_storeu_si128((__m128i *) str, a4);
    }

    // Tail (< 16 bytes): same ASCII-only rule as the vector loop. This avoids
    // passing a negative char to tolower() (undefined behaviour) and keeps
    // the result independent of where in the string a byte happens to fall.
    for (; len > 0; --len, ++str) {
        const unsigned char c = (unsigned char) *str;
        if (c >= 'A' && c <= 'Z') {
            *str = (char) (c + ('a' - 'A'));
        }
    }
}
// SSE2 implementation: converts every lowercase ASCII letter ('a'..'z') in
// the NUL-terminated string `str` to uppercase, in place. All other bytes,
// including bytes >= 0x80, are left untouched, and the result does not
// depend on the current locale.
//
// `str` must point to a writable, NUL-terminated buffer (it is passed to
// strlen()).
void StringToUpper(char *str)
{
    // asciiAofs80: bias that moves 'a'..'z' to the bottom of the signed
    //              byte range (-128..-103).
    // AtoZp1ofs80: first biased value just past 'z' (-102).
    // diff:        (negative) distance from a lowercase letter to its
    //              uppercase form.
    const __m128i asciiAofs80 = _mm_set1_epi8(-'a' + 0x80);
    const __m128i AtoZp1ofs80 = _mm_set1_epi8((signed char) ('z' - 'a' + 1 + 0x80));
    const __m128i diff = _mm_set1_epi8('A' - 'a');

    size_t len = strlen(str);

    // Bulk: 16 bytes per iteration, using unaligned load/store because
    // `str` has no alignment guarantee.
    for (; len >= 16; len -= 16, str += 16) {
        const __m128i src = _mm_loadu_si128((const __m128i *) str);
        const __m128i a1 = _mm_add_epi8(src, asciiAofs80);    // bias every byte
        const __m128i a2 = _mm_cmpgt_epi8(AtoZp1ofs80, a1);   // 0xff where byte is 'a'..'z', else 0x00
        const __m128i a3 = _mm_and_si128(a2, diff);           // 'A'-'a' for letters, else 0
        const __m128i a4 = _mm_add_epi8(src, a3);             // apply the case delta
        _mm_storeu_si128((__m128i *) str, a4);
    }

    // Tail (< 16 bytes): same ASCII-only rule as the vector loop. This avoids
    // passing a negative char to toupper() (undefined behaviour) and keeps
    // the result independent of where in the string a byte happens to fall.
    for (; len > 0; --len, ++str) {
        const unsigned char c = (unsigned char) *str;
        if (c >= 'a' && c <= 'z') {
            *str = (char) (c - ('a' - 'A'));
        }
    }
}
#else
/* Just go scalar. */
// Scalar fallback: lowercases the NUL-terminated string `str` in place using
// the C library's tolower() on each byte.
//
// `str` must point to a writable, NUL-terminated buffer.
void StringToLower(char *str)
{
    for (; *str != '\0'; ++str) {
        // tolower() requires a value representable as unsigned char (or
        // EOF); a plain char may be negative, so convert first.
        *str = (char) tolower((unsigned char) *str);
    }
}
// Scalar fallback: uppercases the NUL-terminated string `str` in place using
// the C library's toupper() on each byte.
//
// `str` must point to a writable, NUL-terminated buffer.
void StringToUpper(char *str)
{
    for (; *str != '\0'; ++str) {
        // toupper() requires a value representable as unsigned char (or
        // EOF); a plain char may be negative, so convert first.
        *str = (char) toupper((unsigned char) *str);
    }
}
#endif
#ifdef TEST
#include <assert.h>
// Reference implementation for the tests: lowercases the first
// `strLengthInBytes` bytes of `str` in place with the C library's tolower().
// The length is explicit, so no NUL terminator is required or consulted.
static void libc_str_lower(char* str, size_t strLengthInBytes) {
    for(size_t i = 0; i < strLengthInBytes; ++i) {
        // Bytes >= 0x80 are negative when char is signed; tolower() is only
        // defined for unsigned char values (or EOF), so convert first.
        str[i] = (char) tolower((unsigned char) str[i]);
    }
}
// Reference implementation for the tests: uppercases the first
// `strLengthInBytes` bytes of `str` in place with the C library's toupper().
// The length is explicit, so no NUL terminator is required or consulted.
static void libc_str_upper(char* str, size_t strLengthInBytes) {
    for(size_t i = 0; i < strLengthInBytes; ++i) {
        // Bytes >= 0x80 are negative when char is signed; toupper() is only
        // defined for unsigned char values (or EOF), so convert first.
        str[i] = (char) toupper((unsigned char) str[i]);
    }
}
// Prints `dataSizeInBytes` bytes starting at `data` to stdout, 16 per row.
// Each row starts with its offset as "0x%04x: ". Bytes in 0x21..0x7e are
// shown as the character followed by a space; all others as two hex digits.
// Entries are separated by a space and the final entry is followed by a
// newline. Nothing is printed when the size is zero.
static void hexDump(const void* data, size_t dataSizeInBytes) {
    const unsigned char* const bytes = (const unsigned char*) data;
    size_t offset = 0;

    while(offset < dataSizeInBytes) {
        const int startsRow = (offset % 16 == 0);

        // Close the previous row before opening a new one.
        if(startsRow && offset != 0) {
            printf("\n");
        }
        if(startsRow) {
            printf("0x%04x: ", (int) offset);
        }

        const unsigned char value = bytes[offset];
        const int isVisible = (value >= 0x21 && value < 0x7f);
        if(isVisible) {
            printf("%c ", value);
        } else {
            printf("%02x", value);
        }

        ++offset;

        // Newline after the last byte, a separator after every other one.
        if(offset >= dataSizeInBytes) {
            printf("\n");
        } else {
            printf(" ");
        }
    }
}
// Test driver. Runs StringToLower/StringToUpper on two fixed strings and on
// a 256-byte string covering every byte value except 0x00, comparing the
// latter against the libc-based reference helpers. When
// DEMONSTRATE_BASIC_ALGORITHM is defined, the scalar and pseudo-SIMD
// demonstration functions are checked against the same references.
//
// Prints the converted data (and, on mismatch, expected/actual dumps) and
// returns EXIT_SUCCESS when no check failed, EXIT_FAILURE otherwise.
int main(void)
{
    int errorCount = 0;

    {
        //            0123456789abcdef0123456789abcdef
        char str[] = "Hello world 12345 HI ABXYZ abxyz";
        StringToLower(str);
        puts(str);
        if(strcmp(str, "hello world 12345 hi abxyz abxyz") != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
        StringToUpper(str);
        puts(str);
        if(strcmp(str, "HELLO WORLD 12345 HI ABXYZ ABXYZ") != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
    }

    {
        //            0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
        char str[] = "Hello world 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz #PAD#";
        char lwr[] = "hello world 12345 hi abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstuvwxyz #pad#";
        char upr[] = "HELLO WORLD 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ #PAD#";
        StringToLower(str);
        puts(str);
        if(strcmp(str, lwr) != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
        StringToUpper(str);
        puts(str);
        if(strcmp(str, upr) != 0) {
            errorCount += 1;
            printf("%s(%d) : FAIL\n", __FILE__, __LINE__);
        }
    }

    // All possible chars except 0x00
    {
        char strSrc[256 + 1];
        // Index 0 holds ' ' instead of 0x00 so the string is not cut short.
        for(size_t i = 0; i < sizeof(strSrc) - 1; ++i) {
            strSrc[i] = (char) (i == 0 ? ' ' : i);
        }
        strSrc[256] = 0;

        for(int iLoop = 0; iLoop <= 1; ++iLoop) {
            char libcStr[sizeof(strSrc)];
            char simdStr[sizeof(strSrc)];
            memcpy(libcStr, strSrc, sizeof(libcStr));
            memcpy(simdStr, strSrc, sizeof(simdStr));
            switch(iLoop) {
            case 0:
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                StringToLower(simdStr);
                printf("StringToLower:\n");
                break;
            case 1:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                StringToUpper(simdStr);
                printf("StringToUpper:\n");
                break;
            default:
                assert(0);
                break;
            }
            if(memcmp(libcStr, simdStr, sizeof(libcStr)) != 0) {
                errorCount += 1;
                printf("\n%s(%d) : FAIL\n", __FILE__, __LINE__);
                printf("Expected:\n");
                hexDump(libcStr, sizeof(libcStr)-1);
                printf("Actual:\n");
                hexDump(simdStr, sizeof(simdStr)-1);
            } else {
                hexDump(simdStr, sizeof(simdStr)-1);
            }
        }

#ifdef DEMONSTRATE_BASIC_ALGORITHM
        for(int iLoop = 0; iLoop <= 3; ++iLoop) {
            char libcStr[sizeof(strSrc)];
            char simdStr[sizeof(strSrc)];
            memcpy(libcStr, strSrc, sizeof(libcStr));
            memcpy(simdStr, strSrc, sizeof(simdStr));
            switch(iLoop) {
            case 0:
                // Only CharToLower is applied here, so a defect in it cannot
                // be hidden by a later StringToLower pass.
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; ++i) {
                    simdStr[i] = (char) CharToLower((int8_t) simdStr[i]);
                }
                printf("CharToLower:\n");
                break;
            case 1:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; ++i) {
                    simdStr[i] = (char) CharToUpper((int8_t) simdStr[i]);
                }
                printf("CharToUpper:\n");
                break;
            case 2:
                libc_str_lower(libcStr, sizeof(libcStr) - 1);
                // simdStr is a plain char array with no 16-byte alignment
                // guarantee, so chunks are moved with memcpy rather than by
                // dereferencing a __m128i pointer.
                for(size_t i = 0; i < sizeof(simdStr)-1; i += 16) {
                    U128 src;
                    memcpy(&src, &simdStr[i], sizeof(src));
                    U128 dst = CharToLowerU128(src);
                    memcpy(&simdStr[i], &dst, sizeof(dst));
                }
                printf("CharToLowerU128:\n");
                break;
            case 3:
                libc_str_upper(libcStr, sizeof(libcStr) - 1);
                for(size_t i = 0; i < sizeof(simdStr)-1; i += 16) {
                    U128 src;
                    memcpy(&src, &simdStr[i], sizeof(src));
                    U128 dst = CharToUpperU128(src);
                    memcpy(&simdStr[i], &dst, sizeof(dst));
                }
                printf("CharToUpperU128:\n");
                break;
            default:
                assert(0);
                break;
            }
            if(memcmp(libcStr, simdStr, sizeof(libcStr)) != 0) {
                errorCount += 1;
                printf("\n%s(%d) : FAIL\n", __FILE__, __LINE__);
                printf("Expected:\n");
                hexDump(libcStr, sizeof(libcStr)-1);
                printf("Actual:\n");
                hexDump(simdStr, sizeof(simdStr)-1);
            } else {
                hexDump(simdStr, sizeof(simdStr)-1);
            }
        }
#endif
    }

    return errorCount == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment