Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include <assert.h>
#include <string.h>
#if USE_SIMD && (__SSE4__ || __AVX2__)
#include <immintrin.h>
#endif

#ifndef DO_INST
#define DO_INST 0
Expand All @@ -43,6 +46,10 @@ limitations under the License.
// Counters and tables used for instrumentation. Objects with static storage
// duration start out as zero, so no explicit initialisers are needed.
// runlens is printed by main() as a table indexed by i, with entry 40 starred.
static int_fast32_t runlens[4096];
static int_fast32_t skips[128];
static int_fast32_t remainders[64];
// Number of times scan_skip() stepped over a block of 32, 16, 8 or 4 bytes
// after the matching non_asciiN() check succeeded; printed by main().
static int_fast32_t non_asciis32, non_asciis16, non_asciis8, non_asciis4;
#endif

#ifndef NDEBUG
Expand All @@ -58,6 +65,32 @@ limitations under the License.
// Branch-prediction hints built on __builtin_expect: likely(x) tells the
// compiler x is expected to be 1, unlikely(x) that it is expected to be 0.
// Both still evaluate to the value of x, so they can wrap any condition.
# define likely(x) __builtin_expect((x), 1)
# define unlikely(x) __builtin_expect((x), 0)

// A 32-byte value that can be viewed at each width the scanner loads at.
// Every member starts at offset 0, so filling `bytes` and then reading a
// wider member yields the first N of those bytes.
typedef union mask {
    uint8_t bytes[32];
    uint8_t u8;
    uint16_t u16;
    uint32_t u32;
    uint64_t u64;
    __uint128_t u128;
#if USE_SIMD && __SSE4__
    // 128-bit vector view.
    __m128i m128i;
#endif
#if USE_SIMD && __AVX2__
    // 256-bit vector view. It must not share a name with the 128-bit member
    // above: both are declared when both guards are true, and duplicate
    // member names do not compile.
    __m256i m256i;
#endif
} mask;

#if USE_SIMD
// 0x80 in each of the 32 byte lanes: the high bit that no 7-bit ASCII
// character has set. Narrower views of the union (for example .u64) select
// the first N lanes of this pattern.
static const mask non_ascii = {
    .bytes = {
        0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
        0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
    },
};
#endif

#if USE_HEX_TABLE
static const bool lhex[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
Expand Down Expand Up @@ -106,6 +139,55 @@ static void print_hit(const unsigned char *buf) {
printf("%.40s\n", buf);
}

#if USE_NON_ASCII4
// Report whether all of the next 4 bytes are non-ASCII, i.e. every byte has
// its high bit (0x80) set, using one 32-bit AND and compare.
// b must point at 4 or more readable bytes; no alignment is required.
// Returns true only when all four bytes are >= 0x80.
static bool non_ascii4(const unsigned char *b) {
    // One 0x80 per byte lane. A mask of 0xffffffff would only match four 0xff
    // bytes, which is not the high-bit test this function is meant to do.
    const uint32_t high_bits = UINT32_C(0x80808080);
    uint32_t word;

    // memcpy is the well-defined way to load a wider integer from a byte
    // pointer of unknown alignment; compilers lower it to a plain load.
    memcpy(&word, b, sizeof word);
    return (word & high_bits) == high_bits;
}
#endif

#if USE_SIMD
// Report whether all of the next 8 bytes have their high bit (0x80) set,
// using one 64-bit AND and compare against the first 8 lanes of non_ascii.
// b must point at 8 or more readable bytes; no alignment is required.
// Returns true only when all eight bytes are >= 0x80.
static bool non_ascii8(const unsigned char *b) {
    uint64_t word;

    // Load through memcpy rather than by casting b to uint64_t *: the cast
    // dropped const and is undefined for unaligned addresses, while memcpy is
    // well defined and compiles to the same single load.
    memcpy(&word, b, sizeof word);
    return (word & non_ascii.u64) == non_ascii.u64;
}
#endif

// NOTE(review): GCC and Clang document __SSE4_1__ / __SSE4_2__ rather than
// __SSE4__, and the two intrinsics used below are SSE2-level -- confirm this
// guard is ever true with the compilers in use.
#if USE_SIMD && __SSE4__
// Report whether all of the next 16 bytes have their high bit (0x80) set.
// 128-bit SSE operation; b needs no particular alignment but must point at
// 16 or more readable bytes.
static bool non_ascii16(const unsigned char *b) {
    // Unaligned load of 16 bytes into one vector register.
    const __m128i chunk = _mm_loadu_si128((const __m128i *)b);

    // movemask gathers the top bit of each byte lane into the low 16 bits of
    // an int. A single ASCII byte clears its bit, and we only take the fast
    // path when every one of the 16 bits is set.
    return (_mm_movemask_epi8(chunk) & 0xffff) == 0xffff;
}
#endif

// NOTE(review): MAX_STEP is #defined inside scan_skip, further down the file.
// Unless it is also defined before this point, the preprocessor treats it as
// 0 here and this function is never compiled -- confirm the intended order.
#if MAX_STEP >= 32 && USE_SIMD && __AVX2__
// Report whether all of the next 32 bytes have their high bit (0x80) set.
// 256-bit AVX2 operation; b needs no particular alignment but must point at
// 32 or more readable bytes.
static bool non_ascii32(const unsigned char *b) {
    // Unaligned load of 32 bytes into one vector register.
    const __m256i chunk = _mm256_loadu_si256((const __m256i *)b);

    // vpmovmskb collects the most significant bit of each of the 32 byte
    // lanes into a 32-bit mask, which is the same as testing (byte & 0x80) on
    // all 32 bytes at once. A byte with that bit set lies outside 7-bit ASCII
    // and so cannot be a hex digit; all 32 bits set means the whole block can
    // be skipped.
    return (uint32_t)_mm256_movemask_epi8(chunk) == UINT32_C(0xffffffff);
}
#endif

// At the start of this function, buf is pointing at a non-hex character and the
// goal is to find the next hex character.
static const unsigned char * scan_skip(const unsigned char *buf, const unsigned char *end) {
Expand All @@ -114,13 +196,75 @@ static const unsigned char * scan_skip(const unsigned char *buf, const unsigned
#ifndef NDEBUG
const unsigned char * io = buf;
#endif

#if USE_SIMD
if (unlikely(buf + 41 >= end)) {
return buf;
}

#define MAX_STEP 8

while (skip > 0 && buf + skip + MAX_STEP < end) {
// Runs of 32+ and 16+ non-ascii bytes are not common
// enough to justify the overhead of using these
#if MAX_STEP >= 32
if (non_ascii32(buf+skip)) {
buf += skip + 32;
skip = 40;
INST(non_asciis32++);
continue;
}
#endif
#if MAX_STEP >= 16
if (non_ascii16(buf+skip)) {
buf += skip + 16;
skip = 40;
INST(non_asciis16++);
continue;
}
#endif
#if MAX_STEP >= 8
if (non_ascii8(buf+skip)) {
buf += skip + 8;
skip = 40;
INST(non_asciis8++);
continue;
}
#endif
// this works but hits so few cases that it doesn't give any benefit
#if USE_NON_ASCII4
if (non_ascii4(buf+skip)) {
buf += skip + 4;
skip = 40;
INST(non_asciis4++);
continue;
}
#endif
if (!is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
continue;
}
skip /= 2;
}

while (skip > 0 && buf + skip < end) {
if (!is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
continue;
}
skip /= 2;
}
#else
do {
while (buf + skip < end && !is_lower_hex(buf+skip)) {
buf += skip;
skip = 40;
}
skip /= 2;
} while (skip > 1 && buf + skip < end);
#endif
assert(io <= buf);
assert(buf < end);
return buf+1;
Expand Down Expand Up @@ -183,6 +327,11 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
// at 50 we know that the current run ends before then and that any runs
// between here and there are too short to care about.

// a sha256 would have ended at buf+24 so buf+25 wouldn't be a hex
if (!is_lower_hex(buf+25) ) {
return scan_skip(buf+25, end);
}

assert(buf +30 < end);

if (!is_lower_hex(buf+30)) {
Expand All @@ -207,6 +356,66 @@ static const unsigned char * scan_hit_long(const unsigned char *buf, const unsig
return scan_hit_short(start, end);
}

#if USE_SIMD && __AVX2__
// Count how many consecutive lowercase hex characters ([0-9a-f]) the 64 bytes
// at start begin with.
//
// Reads exactly 64 bytes using two unaligned 32-byte loads, so the caller
// must make sure that many bytes are readable.
//
// Returns the length of the leading hex run, from 0 to 64. A return of 64
// means all 64 bytes are hex. A 40-character hit followed by a non-hex byte
// shows up as bits 0..39 set and bit 40 clear, i.e. a return value of 40.
static int is_hex64(const unsigned char *start) {
    const __m256i lo = _mm256_loadu_si256((const __m256i *)start);
    const __m256i hi = _mm256_loadu_si256((const __m256i *)(start + 32));

    // Range bounds, one copy per byte lane. The byte compares below are
    // signed, so bytes >= 0x80 are negative, fail both "greater than" tests
    // and are never classed as hex.
    const __m256i below_0 = _mm256_set1_epi8('0' - 1); // 0x2f
    const __m256i digit_9 = _mm256_set1_epi8('9');     // 0x39
    const __m256i below_a = _mm256_set1_epi8('a' - 1); // 0x60
    const __m256i alpha_f = _mm256_set1_epi8('f');     // 0x66

    // Digit lanes: (x > 0x2f) && !(x > 0x39)
    const __m256i digit_lo = _mm256_andnot_si256(_mm256_cmpgt_epi8(lo, digit_9),
                                                 _mm256_cmpgt_epi8(lo, below_0));
    const __m256i digit_hi = _mm256_andnot_si256(_mm256_cmpgt_epi8(hi, digit_9),
                                                 _mm256_cmpgt_epi8(hi, below_0));
    // Letter lanes: (x > 0x60) && !(x > 0x66)
    const __m256i alpha_lo = _mm256_andnot_si256(_mm256_cmpgt_epi8(lo, alpha_f),
                                                 _mm256_cmpgt_epi8(lo, below_a));
    const __m256i alpha_hi = _mm256_andnot_si256(_mm256_cmpgt_epi8(hi, alpha_f),
                                                 _mm256_cmpgt_epi8(hi, below_a));

    // Collapse each 32-lane result to a 32-bit mask (bit i = lane i is hex).
    // Going through uint32_t keeps the widening to 64 bits zero-extended.
    const uint64_t hex_lo = (uint32_t)(_mm256_movemask_epi8(digit_lo) |
                                       _mm256_movemask_epi8(alpha_lo));
    const uint64_t hex_hi = (uint32_t)(_mm256_movemask_epi8(digit_hi) |
                                       _mm256_movemask_epi8(alpha_hi));

    // Bits 0..31 describe bytes 0..31, bits 32..63 describe bytes 32..63.
    const uint64_t hex_bits = hex_lo | (hex_hi << 32);

    // The run length is the index of the lowest clear bit, or 64 if none.
    int run = 0;
    while (run < 64 && ((hex_bits >> run) & 1u)) {
        run++;
    }
    return run;
}
#endif

// We are at the first hex character. The goal is to determine as efficiently as
// possible if this is a 40 hex character run terminated by a non-hex, something
// shorter, or something longer.
Expand All @@ -220,6 +429,23 @@ static const unsigned char * scan_hit_short(const unsigned char *buf, const unsi
return buf;
}

// Use AVX2 instructions to check 32 bytes + 32 bytes
#if USE_SIMD && __AVX2__
if (likely(buf + 64 < end)) {
int len = is_hex64(buf);
assert(len > 0);
assert(len <= 64);
if (len == 40) {
print_hit(buf);
return scan_skip(buf+len, end);
}
if (len < 64) {
return scan_skip(buf+len, end);
}
return scan_hit_long(buf+40, end);
}
#endif

// We know offset 0 is a hex because that's why we're here.
// We know offset 40 needs to be a non-hex otherwise we're in a 41+ run.
// We know 1-39 all need to be hex characters.
Expand Down Expand Up @@ -372,6 +598,10 @@ int main(int argc, const char *argv[]) {
for (int i = 0; i < arr_len(runlens); i++)
if (runlens[i])
dprintf(2, " [%4d] %10d%s\n", i, runlens[i], i==40 ? " *" : "");
dprintf(2, "non-ascii32: %10d\n", non_asciis32);
dprintf(2, "non-ascii16: %10d\n", non_asciis16);
dprintf(2, "non-ascii8: %10d\n", non_asciis8);
dprintf(2, "non-ascii4: %10d\n", non_asciis4);
#endif

return nread;
Expand Down