Skip to content

Commit

Permalink
Infrastructure for developing and testing UTF-8 SIMD decode
Browse files Browse the repository at this point in the history
  • Loading branch information
kovidgoyal committed Nov 17, 2023
1 parent 0d92b27 commit 4c0314a
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 54 deletions.
63 changes: 63 additions & 0 deletions kitty/simd-string-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (C) 2023 Kovid Goyal <kovid at kovidgoyal.net>
*
* Distributed under terms of the GPL3 license.
*/

#ifndef BITS
#define BITS 128
#endif

#ifdef __clang__
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wbitwise-instead-of-logical\"")
#endif
#include <simde/x86/avx2.h>
#ifdef __clang__
_Pragma("clang diagnostic pop")
#endif


#if BITS == 128
#define FUNC(name) name##_##128
#define integer_t __m128i
#define set1_epi8 simde_mm_set1_epi8
#define load_unaligned simde_mm_loadu_si128
#define cmpeq_epi8 simde_mm_cmpeq_epi8
#define or_si simde_mm_or_si128
#define movemask_epi8 simde_mm_movemask_epi8
#else
#define FUNC(name) name##_##256
#define integer_t __m256i
#define set1_epi8 simde_mm256_set1_epi8
#define load_unaligned simde_mm256_loadu_si256
#define cmpeq_epi8 simde_mm256_cmpeq_epi8
#define or_si simde_mm256_or_si256
#define movemask_epi8 simde_mm256_movemask_epi8
#endif

static inline const uint8_t*
FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b) {
integer_t a_vec = set1_epi8(a), b_vec = set1_epi8(b);
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(integer_t)) {
const integer_t chunk = load_unaligned((integer_t*)haystack);
const integer_t a_cmp = cmpeq_epi8(chunk, a_vec);
const integer_t b_cmp = cmpeq_epi8(chunk, b_vec);
const integer_t matches = or_si(a_cmp, b_cmp);
const int mask = movemask_epi8(matches);
if (mask != 0) {
size_t pos = __builtin_ctz(mask);
const uint8_t *ans = haystack + pos;
if (ans < limit) return ans;
}
}
return NULL;
}


#undef FUNC
#undef integer_t
#undef set1_epi8
#undef load_unaligned
#undef cmpeq_epi8
#undef or_si
#undef movemask_epi8
120 changes: 66 additions & 54 deletions kitty/simd-string.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@
#include "data-types.h"
#include "charsets.h"
#include "simd-string.h"
#ifdef __clang__
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wbitwise-instead-of-logical\"")
#endif
#include <simde/x86/avx2.h>
#ifdef __clang__
_Pragma("clang diagnostic pop")
#endif

#include "simd-string-impl.h"
#undef BITS
#define BITS 256
#include "simd-string-impl.h"
static bool has_sse4_2 = false, has_avx2 = false;

// find_either_of_two_bytes {{{
Expand All @@ -27,49 +23,6 @@ find_either_of_two_bytes_scalar(const uint8_t *haystack, const size_t sz, const
}
return NULL;
}
#undef SHIFT_OP

#define _mm128_set1_epi8 _mm_set1_epi8
#define _mm128_load_si128 _mm_load_si128
#define _mm128_loadu_si128 _mm_loadu_si128
#define _mm128_cmpeq_epi8 _mm_cmpeq_epi8
#define _mm128_or_si128 _mm_or_si128
#define _mm128_movemask_epi8 _mm_movemask_epi8
#define _mm128_cmpgt_epi8 _mm_cmpgt_epi8
#define _mm128_and_si128 _mm_and_si128

#define start_simd2(bits) \
__m##bits##i a_vec = _mm##bits##_set1_epi8(a); \
__m##bits##i b_vec = _mm##bits##_set1_epi8(b); \
for (const uint8_t* limit = haystack + sz; haystack < limit; haystack += sizeof(__m##bits##i))

#define end_simd2 \
if (mask != 0) { \
size_t pos = __builtin_ctz(mask); \
if (haystack + pos < limit) return haystack + pos; \
}

#define either_of_two(bits) \
start_simd2(bits) { \
__m##bits##i chunk = _mm##bits##_loadu_si##bits((__m##bits##i*)(haystack)); \
__m##bits##i a_cmp = _mm##bits##_cmpeq_epi8(chunk, a_vec); \
__m##bits##i b_cmp = _mm##bits##_cmpeq_epi8(chunk, b_vec); \
__m##bits##i matches = _mm##bits##_or_si##bits(a_cmp, b_cmp); \
const int mask = _mm##bits##_movemask_epi8(matches); \
end_simd2; \
} return NULL;

static const uint8_t*
find_either_of_two_bytes_sse4_2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
either_of_two(128);
}


static const uint8_t*
find_either_of_two_bytes_avx2(const uint8_t *haystack, size_t sz, const uint8_t a, const uint8_t b) {
either_of_two(256);
}


static const uint8_t* (*find_either_of_two_bytes_impl)(const uint8_t*, const size_t, const uint8_t, const uint8_t) = find_either_of_two_bytes_scalar;

Expand Down Expand Up @@ -113,15 +66,73 @@ utf8_decode_to_sentinel_scalar(UTF8Decoder *d, const uint8_t *src, const size_t
return num_consumed;
}

static unsigned
utf8_decode_to_sentinel_sse4_2(UTF8Decoder *d, const uint8_t *src, const size_t src_sz, const uint8_t sentinel) {
(void)d; (void)src; (void)src_sz; (void)sentinel;
return 0;
}

static unsigned (*utf8_decode_to_sentinel_impl)(UTF8Decoder *d, const uint8_t *src, const size_t src_sz, const uint8_t sentinel) = utf8_decode_to_sentinel_scalar;

unsigned
utf8_decode_to_sentinel(UTF8Decoder *d, const uint8_t *src, const size_t src_sz, const uint8_t sentinel) {
return utf8_decode_to_sentinel_scalar(d, src, src_sz, sentinel);
return utf8_decode_to_sentinel_impl(d, src, src_sz, sentinel);
}

// }}}

// Boilerplate {{{
static void
test_control_byte_callback(void *l, uint8_t ch) {
if (!PyErr_Occurred()) {
RAII_PyObject(c, PyLong_FromUnsignedLong((unsigned long)ch));
if (c) PyList_Append((PyObject*)l, c);
}
}

static void
test_output_chars_callback(void *l, const uint32_t *chars, unsigned sz) {
if (!PyErr_Occurred()) {
RAII_PyObject(c, PyUnicode_FromKindAndData(PyUnicode_4BYTE_KIND, chars, (Py_ssize_t)sz));
if (c) PyList_Append((PyObject*)l, c);
}
}

static PyObject*
test_utf8_decode_to_sentinel(PyObject *self UNUSED, PyObject *args) {
const uint8_t *src; Py_ssize_t src_sz;
int which_function = 0;
static UTF8Decoder d = {0};
unsigned char sentinel = 0x1b;
if (!PyArg_ParseTuple(args, "s#|iB", &src, &src_sz, &which_function, &sentinel)) return NULL;
RAII_PyObject(ans, PyList_New(0));
d.callback_data = ans;
d.control_byte_callback = test_control_byte_callback;
d.output_chars_callback = test_output_chars_callback;
unsigned long consumed;
switch(which_function) {
case -1:
zero_at_ptr(&d); Py_RETURN_NONE;
case 1:
consumed = utf8_decode_to_sentinel_scalar(&d, src, src_sz, sentinel); break;
case 2:
consumed = utf8_decode_to_sentinel_sse4_2(&d, src, src_sz, sentinel); break;
default:
consumed = utf8_decode_to_sentinel(&d, src, src_sz, sentinel); break;
}
return Py_BuildValue("kO", consumed, ans);
}
// }}}

static PyMethodDef module_methods[] = {
METHODB(test_utf8_decode_to_sentinel, METH_VARARGS),
{NULL, NULL, 0, NULL} /* Sentinel */
};

bool
init_simd(void *x) {
PyObject *module = (PyObject*)x;
if (PyModule_AddFunctions(module, module_methods) != 0) return false;
#define A(x, val) { Py_INCREF(Py_##val); if (0 != PyModule_AddObject(module, #x, Py_##val)) return false; }
#ifdef __APPLE__
#ifdef __arm64__
Expand All @@ -141,13 +152,14 @@ init_simd(void *x) {
#endif
if (has_avx2) {
A(has_avx2, True);
find_either_of_two_bytes_impl = find_either_of_two_bytes_avx2;
find_either_of_two_bytes_impl = find_either_of_two_bytes_256;
} else {
A(has_avx2, False);
}
if (has_sse4_2) {
A(has_sse4_2, True);
if (find_either_of_two_bytes_impl == find_either_of_two_bytes_scalar) find_either_of_two_bytes_impl = find_either_of_two_bytes_sse4_2;
if (find_either_of_two_bytes_impl == find_either_of_two_bytes_scalar) find_either_of_two_bytes_impl = find_either_of_two_bytes_128;
/* if (utf8_decode_to_sentinel_impl == utf8_decode_to_sentinel_scalar) utf8_decode_to_sentinel_impl = utf8_decode_to_sentinel_sse4_2; */
} else {
A(has_sse4_2, False);
}
Expand Down

0 comments on commit 4c0314a

Please sign in to comment.