Skip to content

Instantly share code, notes, and snippets.

@easyaspi314
Created February 6, 2019 06:06
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save easyaspi314/9d31e5c0f9cead66aba2ede248b74d64 to your computer and use it in GitHub Desktop.
Save easyaspi314/9d31e5c0f9cead66aba2ede248b74d64 to your computer and use it in GitHub Desktop.
SIMD functions to apply toupper/tolower to each character in a string
/* Created by easyaspi314. Released into the public domain. */
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef __SSE2__
#include <immintrin.h>
/* A SIMD function for SSE2 which changes all uppercase ASCII letters
 * ('A'..'Z') in the NUL-terminated string str to lowercase, in place.
 *
 * The string is processed 16 bytes at a time; the final (strlen % 16)
 * bytes are handled by a scalar loop. The vector path leaves every byte
 * outside 'A'..'Z' unchanged. */
void StringToLower(char *str)
{
    size_t len = strlen(str);
    /* SSE2 only offers strict signed byte compares, so the inclusive range
     * 'A'..'Z' is written as ('A' - 1) < x < ('Z' + 1). Comparing against
     * 'A' itself (as the original did) wrongly excluded 'A'. Bytes >= 0x80
     * are negative as signed bytes and therefore fall outside the range. */
    const __m128i belowA = _mm_set1_epi8('A' - 1);
    const __m128i aboveZ = _mm_set1_epi8('Z' + 1);
    const __m128i diff = _mm_set1_epi8('a' - 'A');
    while (len >= 16) {
        __m128i inp = _mm_loadu_si128((const __m128i *)str);
        /* x >= 'A': 0xFF, else 0x00 */
        __m128i geA = _mm_cmpgt_epi8(inp, belowA);
        /* x <= 'Z': 0xFF, else 0x00 */
        __m128i leZ = _mm_cmplt_epi8(inp, aboveZ);
        /* 'A' <= x <= 'Z': 0xFF, else 0x00 */
        __m128i mask = _mm_and_si128(geA, leZ);
        /* add 'a' - 'A' only in the uppercase lanes */
        __m128i lowered = _mm_add_epi8(inp, _mm_and_si128(mask, diff));
        _mm_storeu_si128((__m128i *)str, lowered);
        len -= 16;
        str += 16;
    }
    /* Scalar tail. <ctype.h> functions require a value representable as
     * unsigned char (or EOF); passing a negative plain char is UB.
     * NOTE(review): tolower is locale-dependent; outside the "C" locale the
     * tail may convert bytes that the ASCII-only vector path leaves alone. */
    for (; len > 0; --len, ++str) {
        *str = (char)tolower((unsigned char)*str);
    }
}
/* SSE2 counterpart of StringToLower: changes all lowercase ASCII letters
 * ('a'..'z') in the NUL-terminated string str to uppercase, in place.
 *
 * The string is processed 16 bytes at a time; the final (strlen % 16)
 * bytes are handled by a scalar loop. The vector path leaves every byte
 * outside 'a'..'z' unchanged. */
void StringToUpper(char *str)
{
    size_t len = strlen(str);
    /* Strict signed compares: the inclusive range 'a'..'z' is written as
     * ('a' - 1) < x < ('z' + 1). Comparing against 'a' itself (as the
     * original did) wrongly excluded 'a'. */
    const __m128i belowa = _mm_set1_epi8('a' - 1);
    const __m128i abovez = _mm_set1_epi8('z' + 1);
    const __m128i diff = _mm_set1_epi8('a' - 'A');
    while (len >= 16) {
        __m128i inp = _mm_loadu_si128((const __m128i *)str);
        /* x >= 'a': 0xFF, else 0x00 */
        __m128i gea = _mm_cmpgt_epi8(inp, belowa);
        /* x <= 'z': 0xFF, else 0x00 */
        __m128i lez = _mm_cmplt_epi8(inp, abovez);
        /* 'a' <= x <= 'z': 0xFF, else 0x00 */
        __m128i mask = _mm_and_si128(gea, lez);
        /* subtract 'a' - 'A' only in the lowercase lanes */
        __m128i raised = _mm_sub_epi8(inp, _mm_and_si128(mask, diff));
        _mm_storeu_si128((__m128i *)str, raised);
        len -= 16;
        str += 16;
    }
    /* Scalar tail; cast to unsigned char because passing a negative char
     * to toupper is UB.
     * NOTE(review): toupper is locale-dependent; outside the "C" locale the
     * tail may convert bytes that the ASCII-only vector path leaves alone. */
    for (; len > 0; --len, ++str) {
        *str = (char)toupper((unsigned char)*str);
    }
}
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
/* NEON version: changes all uppercase ASCII letters ('A'..'Z') in the
 * NUL-terminated string str to lowercase, in place.
 *
 * The string is processed 16 bytes at a time; the final (strlen % 16)
 * bytes are handled by a scalar loop. */
void StringToLower(char *str)
{
    size_t len = strlen(str);
    /* NEON has inclusive unsigned compares, so the range is exactly
     * 'A' <= x <= 'Z'. The original strict "greater than 'A'" excluded
     * 'A' itself. Bytes >= 0x80 are above 'Z' as unsigned values. */
    const uint8x16_t asciiA = vdupq_n_u8('A');
    const uint8x16_t asciiZ = vdupq_n_u8('Z');
    const uint8x16_t diff = vdupq_n_u8('a' - 'A');
    while (len >= 16) {
        uint8x16_t inp = vld1q_u8((const uint8_t *)str);
        /* 'A' <= x <= 'Z': 0xFF, else 0x00 */
        uint8x16_t mask = vandq_u8(vcgeq_u8(inp, asciiA), vcleq_u8(inp, asciiZ));
        /* add 'a' - 'A' only in the uppercase lanes */
        uint8x16_t lowered = vaddq_u8(inp, vandq_u8(mask, diff));
        vst1q_u8((uint8_t *)str, lowered);
        len -= 16;
        str += 16;
    }
    /* Scalar tail; cast to unsigned char because passing a negative char
     * to tolower is UB.
     * NOTE(review): tolower is locale-dependent; outside the "C" locale the
     * tail may convert bytes that the ASCII-only vector path leaves alone. */
    for (; len > 0; --len, ++str) {
        *str = (char)tolower((unsigned char)*str);
    }
}
/* NEON version: changes all lowercase ASCII letters ('a'..'z') in the
 * NUL-terminated string str to uppercase, in place.
 *
 * The string is processed 16 bytes at a time; the final (strlen % 16)
 * bytes are handled by a scalar loop. */
void StringToUpper(char *str)
{
    size_t len = strlen(str);
    /* Inclusive unsigned compares: 'a' <= x <= 'z'. The original strict
     * "greater than 'a'" excluded 'a' itself. */
    const uint8x16_t asciia = vdupq_n_u8('a');
    const uint8x16_t asciiz = vdupq_n_u8('z');
    const uint8x16_t diff = vdupq_n_u8('a' - 'A');
    while (len >= 16) {
        uint8x16_t inp = vld1q_u8((const uint8_t *)str);
        /* 'a' <= x <= 'z': 0xFF, else 0x00 */
        uint8x16_t mask = vandq_u8(vcgeq_u8(inp, asciia), vcleq_u8(inp, asciiz));
        /* subtract 'a' - 'A' only in the lowercase lanes */
        uint8x16_t raised = vsubq_u8(inp, vandq_u8(mask, diff));
        vst1q_u8((uint8_t *)str, raised);
        len -= 16;
        str += 16;
    }
    /* Scalar tail; cast to unsigned char because passing a negative char
     * to toupper is UB.
     * NOTE(review): toupper is locale-dependent; outside the "C" locale the
     * tail may convert bytes that the ASCII-only vector path leaves alone. */
    for (; len > 0; --len, ++str) {
        *str = (char)toupper((unsigned char)*str);
    }
}
#else
/* Just go scalar: portable fallback when neither SSE2 nor NEON is
 * available. Applies tolower to every byte of the NUL-terminated string
 * str, in place. */
void StringToLower(char *str)
{
    for (; *str != '\0'; ++str) {
        /* <ctype.h> requires an unsigned char value (or EOF); passing a
         * negative plain char (e.g. a byte >= 0x80 where char is signed)
         * is undefined behavior. */
        *str = (char)tolower((unsigned char)*str);
    }
}
/* Portable fallback: applies toupper to every byte of the NUL-terminated
 * string str, in place. */
void StringToUpper(char *str)
{
    for (; *str != '\0'; ++str) {
        /* Cast to unsigned char: passing a negative plain char to toupper
         * is undefined behavior. */
        *str = (char)toupper((unsigned char)*str);
    }
}
#endif
#ifdef TEST
#include <assert.h>
int main(void)
{
    /* The string must be well over 16 bytes so that most of it goes through
     * the SIMD loop rather than the scalar tail. Its first 16 bytes
     * ("AZaz@[`{ Hello w") contain the boundary letters 'A', 'Z', 'a', 'z'
     * and their out-of-range neighbours '@', '[', '`', '{', which catches
     * off-by-one range checks in the vector code. */
    char str[] = "AZaz@[`{ Hello world 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ "
                 "abcdefghijklmnopqrstuvwxyz #PAD#";
    StringToLower(str);
    assert(!strcmp(str, "azaz@[`{ hello world 12345 hi abcdefghijklmnopqrstuvwxyz "
                        "abcdefghijklmnopqrstuvwxyz #pad#"));
    puts(str);
    StringToUpper(str);
    assert(!strcmp(str, "AZAZ@[`{ HELLO WORLD 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ "
                        "ABCDEFGHIJKLMNOPQRSTUVWXYZ #PAD#"));
    puts(str);
    return 0;
}
#endif
@bb1950328
Copy link

This doesn't convert 'A' to 'a'.
Changing _mm_set1_epi8('A') to _mm_set1_epi8('A' - 1) fixes the problem. I haven't tested the other three functions yet, but I think they contain the same mistake.

@25077667
Copy link

25077667 commented Sep 1, 2023

@bb1950328 Yeah, you are correct. GitHub Copilot learned this code, and it made me struggle for an hour.

@t-mat
Copy link

t-mat commented Dec 18, 2023

Hi, thanks for the useful code example.

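It seems the TEST does not work as intended. Since strlen(str) == 25 (sizeof(str) == 26), only the first 16 bytes ("Hello world 1234") go through the SIMD loop; the rest ("5 HI AZaz"), which contains all the interesting letters, is processed by the scalar loop.
If you pass str[] = "Hello world 12345 hi ABXYZ abxyz" or a longer string, the TEST will fail as it should.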

As for the actual functions, as @bb1950328 suggested, the following patch may work:

patch
void StringToLower(char *str)
{
    size_t len = strlen(str);
-   const __m128i asciiA = _mm_set1_epi8('A');
+   const __m128i asciiA = _mm_set1_epi8('A' - 1);
    const __m128i asciiZ = _mm_set1_epi8('Z' + 1);
...

void StringToUpper(char *str)
{
    size_t len = strlen(str);
-   const __m128i asciia = _mm_set1_epi8('a');
+   const __m128i asciia = _mm_set1_epi8('a' - 1);
    const __m128i asciiz = _mm_set1_epi8('z' + 1);
...

void StringToLower(char *str)
{
    size_t len = strlen(str);
-   const uint8x16_t asciiA = vdupq_n_u8('A');
+   const uint8x16_t asciiA = vdupq_n_u8('A' - 1);
    const uint8x16_t asciiZ = vdupq_n_u8('Z' + 1);
...

void StringToUpper(char *str)
{
    size_t len = strlen(str);
-   const uint8x16_t asciia = vdupq_n_u8('a');
+   const uint8x16_t asciia = vdupq_n_u8('a' - 1);
    const uint8x16_t asciiz = vdupq_n_u8('z' + 1);
...

int main()
{
-   char str[] = "Hello world 12345 HI AZaz";
+   char str[] = "Hello world 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz #PAD#";
    StringToLower(str);
-   assert(!strcmp(str, "hello world 12345 hi azaz"));
+   assert(!strcmp(str, "hello world 12345 hi abcdefghijklmnopqrstuvwxyz abcdefghijklmnopqrstuvwxyz #pad#"));
    puts(str);

    StringToUpper(str);
-   assert(!strcmp(str, "HELLO WORLD 12345 HI AZAZ"));
+   assert(!strcmp(str, "HELLO WORLD 12345 HI ABCDEFGHIJKLMNOPQRSTUVWXYZ ABCDEFGHIJKLMNOPQRSTUVWXYZ #PAD#"));
    puts(str);
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment