httparse/simd/avx2.rs

use crate::iter::Bytes;

#[inline]
#[target_feature(enable = "avx2")]
pub unsafe fn match_uri_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 32 {
        let advance = match_url_char_32_avx(bytes.as_ref());
        bytes.advance(advance);

        if advance != 32 {
            return;
        }
    }
    // NOTE: use SWAR for <32B, more efficient than falling back to SSE4.2
    super::swar::match_uri_vectored(bytes)
}
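
// Illustrative sketch (not part of the original file): a scalar model of the
// byte ranges documented in `match_url_char_32_avx` below (%x21-%x7e and
// %x80-%xff), i.e. every byte >= 0x21 except DEL (0x7f). The helper name is
// hypothetical and only meant to make the vectorized predicate easier to read.
#[cfg(test)]
#[allow(dead_code)]
fn uri_char_allowed_scalar(b: u8) -> bool {
    b >= 0x21 && b != 0x7f
}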

#[inline(always)]
#[allow(non_snake_case, overflowing_literals)]
#[allow(unused)]
unsafe fn match_url_char_32_avx(buf: &[u8]) -> usize {
    // NOTE: This check might not be necessary, since this function is only called
    // from `match_uri_vectored`, which already guarantees at least 32 bytes.
    debug_assert!(buf.len() >= 32);

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // pointer to buffer
    let ptr = buf.as_ptr();

    // %x21-%x7e %x80-%xff
    //
    // Character ranges allowed by this function, which can also be read as:
    // 33 <= x <= 255, excluding 127 (DEL)
    //
    // Create a vector full of DEL (0x7f) characters.
    let DEL: __m256i = _mm256_set1_epi8(0x7f);
    // Create a vector full of exclamation mark (!) (0x21) characters.
    // Used as the lower threshold; valid URI characters cannot be smaller than this.
    let LOW: __m256i = _mm256_set1_epi8(0x21);

    // Load a chunk of 32 bytes from `ptr` as a vector.
    // We can check at most 32 bytes in parallel with AVX2, since
    // YMM registers are 256 bits wide.
    let dat = _mm256_lddqu_si256(ptr as *const _);

    // unsigned comparison dat >= LOW
    //
    // `_mm256_max_epu8` creates a new vector by taking the byte-wise maximum
    // of `dat` and `LOW`.
    // So any byte in `dat` that is <= 0x20 is clamped up to 0x21,
    // the smallest valid character.
    //
    // Then, we compare the new vector with `dat` for equality.
    //
    // `_mm256_cmpeq_epi8` returns a new vector where:
    // * matching bytes are set to 0xFF (all bits set),
    // * nonmatching bytes are set to 0 (no bits set).
    let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat);
    // Similar to the above, but this time bytes equal to DEL (an invalid character) are set to 0xFF.
    let del = _mm256_cmpeq_epi8(dat, DEL);

    // We combine both comparisons via `_mm256_andnot_si256`.
    //
    // Since truthiness is represented differently in the two comparisons
    // (0xFF marks valid bytes in `low` but invalid bytes in `del`),
    // `del` must be negated before the AND, i.e. `!del & low`.
    let bit = _mm256_andnot_si256(del, low);
    // This creates a bitmask from the most significant bit of each byte.
    // In effect, it converts the vector result into a scalar value.
    let res = _mm256_movemask_epi8(bit) as u32;

    // Count trailing zeros to find the first encountered invalid character.
    // Bitwise NOT is required once again to flip truthiness.
    // TODO: use .trailing_ones() once MSRV >= 1.46
    (!res).trailing_zeros() as usize
}
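
// Worked example (added for illustration, not part of the original source):
// how the movemask + trailing_zeros step above turns the lane-wise comparison
// into a count of leading valid bytes. If every lane except lane 5 passes the
// predicate, bit 5 of the mask is 0; inverting the mask and counting trailing
// zeros yields 5, so 5 bytes are consumed before stopping. The test name is
// hypothetical.
#[test]
fn movemask_trailing_zeros_example() {
    // Hypothetical mask: all 32 lanes valid except lane 5.
    let res: u32 = !(1u32 << 5);
    assert_eq!((!res).trailing_zeros() as usize, 5);

    // When all 32 lanes are valid, the inverted mask is zero and
    // trailing_zeros saturates at 32, i.e. the whole chunk is consumed.
    assert_eq!((!u32::MAX).trailing_zeros() as usize, 32);
}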

#[target_feature(enable = "avx2")]
pub unsafe fn match_header_value_vectored(bytes: &mut Bytes) {
    while bytes.as_ref().len() >= 32 {
        let advance = match_header_value_char_32_avx(bytes.as_ref());
        bytes.advance(advance);

        if advance != 32 {
            return;
        }
    }
    // NOTE: use SWAR for <32B, more efficient than falling back to SSE4.2
    super::swar::match_header_value_vectored(bytes)
}
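
// Hedged usage sketch (not part of the original file): a 40-byte header value
// run through the AVX2 path. The first 32 bytes are consumed by the loop
// above; the remaining 8 fall through to the SWAR routine, and the cursor
// stops right before the first invalid byte. This assumes, as the tests below
// do, that the matchers advance exactly up to the first disallowed byte.
#[test]
fn header_value_chunking_example() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    let mut data = [b'a'; 40];
    data[37] = 0; // NUL is not a valid header-value byte
    let mut bytes = Bytes::new(&data);

    // SAFETY: guarded by the avx2 feature check above.
    unsafe { match_header_value_vectored(&mut bytes) };

    assert_eq!(bytes.pos(), 37);
}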

#[inline(always)]
#[allow(non_snake_case)]
#[allow(unused)]
unsafe fn match_header_value_char_32_avx(buf: &[u8]) -> usize {
    debug_assert!(buf.len() >= 32);

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let ptr = buf.as_ptr();

    // %x09 %x20-%x7e %x80-%xff
    // Create a vector full of horizontal tab (\t) (0x09) characters.
    let TAB: __m256i = _mm256_set1_epi8(0x09);
    // Create a vector full of DEL (0x7f) characters.
    let DEL: __m256i = _mm256_set1_epi8(0x7f);
    // Create a vector full of space (0x20) characters.
    let LOW: __m256i = _mm256_set1_epi8(0x20);

    // Load a chunk of 32 bytes from `ptr` as a vector.
    let dat = _mm256_lddqu_si256(ptr as *const _);

    // unsigned comparison dat >= LOW
    //
    // Same as in `match_url_char_32_avx`, except that this time
    // the lower threshold is the space character (0x20).
    let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat);
    // Check if `dat` includes `TAB` characters.
    let tab = _mm256_cmpeq_epi8(dat, TAB);
    // Check if `dat` includes `DEL` characters.
    let del = _mm256_cmpeq_epi8(dat, DEL);

    // Combine all comparisons: `low` and `tab` are ORed together,
    // while the bits of `del` are flipped before the AND.
    //
    // In the end, this is simply:
    // ~del & (low | tab)
    let bit = _mm256_andnot_si256(del, _mm256_or_si256(low, tab));
    // This creates a bitmask from the most significant bit of each byte,
    // converting the vector result into a scalar value.
    let res = _mm256_movemask_epi8(bit) as u32;

    // Count trailing zeros to find the first encountered invalid character.
    // Bitwise NOT is required once again to flip truthiness.
    // TODO: use .trailing_ones() once MSRV >= 1.46
    (!res).trailing_zeros() as usize
}
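
// Illustrative sketch (not part of the original file): a scalar model of the
// header-value ranges documented above (%x09, %x20-%x7e, %x80-%xff), i.e.
// HTAB or any byte >= 0x20 except DEL (0x7f). The helper name is hypothetical.
#[cfg(test)]
#[allow(dead_code)]
fn header_value_char_allowed_scalar(b: u8) -> bool {
    b == b'\t' || (b >= 0x20 && b != 0x7f)
}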

#[test]
fn avx2_code_matches_uri_chars_table() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    #[allow(clippy::undocumented_unsafe_blocks)]
    unsafe {
        assert!(byte_is_allowed(b'_', match_uri_vectored));

        for (b, allowed) in crate::URI_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_uri_vectored), allowed,
                "byte_is_allowed({:?}) should be {:?}", b, allowed,
            );
        }
    }
}

#[test]
fn avx2_code_matches_header_value_chars_table() {
    if !is_x86_feature_detected!("avx2") {
        return;
    }

    #[allow(clippy::undocumented_unsafe_blocks)]
    unsafe {
        assert!(byte_is_allowed(b'_', match_header_value_vectored));

        for (b, allowed) in crate::HEADER_VALUE_MAP.iter().cloned().enumerate() {
            assert_eq!(
                byte_is_allowed(b as u8, match_header_value_vectored), allowed,
                "byte_is_allowed({:?}) should be {:?}", b, allowed,
            );
        }
    }
}

#[cfg(test)]
unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool {
    let slice = [
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', b'_', b'_',
        b'_', b'_', byte, b'_',
        b'_', b'_', b'_', b'_',
    ];
    let mut bytes = Bytes::new(&slice);

    f(&mut bytes);

    // The probe byte sits at index 26 of the 32-byte slice: consuming all
    // 32 bytes means it was accepted, while stopping at index 26 means it
    // was rejected.
    match bytes.pos() {
        32 => true,
        26 => false,
        _ => unreachable!(),
    }
}