Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 238 additions & 33 deletions src/SharpArena/Collections/ArenaDictionary.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections;
using System.Buffers;
using System.Text;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -242,29 +244,125 @@ public bool ContainsKey(TKey key)
}

/// <summary>
/// Specialized ContainsKey for ArenaUtf16String using ReadOnlySpan{char} to avoid allocations.
/// Specialized ContainsKey for string types using ReadOnlySpan{char} to avoid allocations.
/// Supports cross-encoding for ArenaUtf8String.
/// </summary>
/// <param name="key">The key to locate in the <see cref="ArenaDictionary{TKey, TValue}"/>.</param>
/// <returns><see langword="true"/> if the <see cref="ArenaDictionary{TKey, TValue}"/> contains an element with the key; otherwise, <see langword="false"/>.</returns>
public bool ContainsKey(ReadOnlySpan<char> key)
{
if (typeof(TKey) != typeof(ArenaUtf16String)) return false;
CheckAlive();
if (typeof(TKey) == typeof(ArenaUtf16String))
{
CheckAlive();

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
uint mask = capacity - 1;
uint hash = Hashing.HashString(key);
uint index = hash & mask;

while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) return false;
if (keys[entryIdxPlusOne - 1].Equals(key)) return true;
index = (index + 1) & mask;
}
}

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
uint mask = capacity - 1;
uint hash = Hashing.HashString(key);
uint index = hash & mask;
if (typeof(TKey) == typeof(ArenaUtf8String))
{
CheckAlive();
int maxBytes = Encoding.UTF8.GetMaxByteCount(key.Length);
byte[]? rented = null;
Span<byte> buffer = maxBytes <= 512 ? stackalloc byte[512] : (rented = ArrayPool<byte>.Shared.Rent(maxBytes));
try
{
int written = Encoding.UTF8.GetBytes(key, buffer);
ReadOnlySpan<byte> byteKey = buffer.Slice(0, written);

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf8String* keys = (ArenaUtf8String*)_header->Keys;
uint mask = capacity - 1;
uint hash = Hashing.HashUtf8(byteKey);
uint index = hash & mask;
while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) return false;
if (keys[entryIdxPlusOne - 1].Equals(byteKey)) return true;
index = (index + 1) & mask;
}
}
finally
{
if (rented != null) ArrayPool<byte>.Shared.Return(rented);
}
}

while (true)
return false;
}

/// <summary>
/// Specialized ContainsKey for ArenaUtf8String using ReadOnlySpan{byte} to avoid allocations.
/// </summary>
/// <param name="key">The key to locate in the <see cref="ArenaDictionary{TKey, TValue}"/>.</param>
/// <returns><see langword="true"/> if the <see cref="ArenaDictionary{TKey, TValue}"/> contains an element with the key; otherwise, <see langword="false"/>.</returns>
public bool ContainsKey(ReadOnlySpan<byte> key)
{
if (typeof(TKey) == typeof(ArenaUtf8String))
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) return false;
if (keys[entryIdxPlusOne - 1].Equals(key)) return true;
index = (index + 1) & mask;
CheckAlive();
uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf8String* keys = (ArenaUtf8String*)_header->Keys;
uint mask = capacity - 1;
uint hash = Hashing.HashUtf8(key);
uint index = hash & mask;
while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) return false;
if (keys[entryIdxPlusOne - 1].Equals(key)) return true;
index = (index + 1) & mask;
}
}

if (typeof(TKey) == typeof(ArenaUtf16String))
{
CheckAlive();
int maxChars = Encoding.UTF8.GetMaxCharCount(key.Length);
char[]? rented = null;
Span<char> buffer = maxChars <= 512 ? stackalloc char[512] : (rented = ArrayPool<char>.Shared.Rent(maxChars));
try
{
int written = Encoding.UTF8.GetChars(key, buffer);
ReadOnlySpan<char> charKey = buffer.Slice(0, written);

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
uint mask = capacity - 1;
uint hash = Hashing.HashString(charKey);
uint index = hash & mask;

while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) return false;
if (keys[entryIdxPlusOne - 1].Equals(charKey)) return true;
index = (index + 1) & mask;
}
}
finally
{
if (rented != null) ArrayPool<char>.Shared.Return(rented);
}
}

return false;
}

/// <summary>
Expand All @@ -288,38 +386,145 @@ public bool TryGetValue(TKey key, out TValue value)
}

/// <summary>
/// Specialized TryGetValue for ArenaUtf16String using ReadOnlySpan{char} to avoid allocations.
/// Specialized TryGetValue for string types using ReadOnlySpan{char} to avoid allocations.
/// Supports cross-encoding for ArenaUtf8String.
/// </summary>
/// <param name="key">The key whose value to get.</param>
/// <param name="value">When this method returns, the value associated with the specified key, if the key is found; otherwise, the default value for the type of the <paramref name="value"/> parameter. This parameter is passed uninitialized.</param>
/// <returns><see langword="true"/> if the <see cref="ArenaDictionary{TKey, TValue}"/> contains an element with the specified key; otherwise, <see langword="false"/>.</returns>
public bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
{
if (typeof(TKey) != typeof(ArenaUtf16String))
if (typeof(TKey) == typeof(ArenaUtf16String))
{
value = default;
return false;
CheckAlive();

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
TValue* values = (TValue*)_header->Values;
uint mask = capacity - 1;
uint hash = Hashing.HashString(key);
uint index = hash & mask;

while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) break;
if (keys[entryIdxPlusOne - 1].Equals(key))
{
value = values[entryIdxPlusOne - 1];
return true;
}
index = (index + 1) & mask;
}
}
CheckAlive();

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
TValue* values = (TValue*)_header->Values;
uint mask = capacity - 1;
uint hash = Hashing.HashString(key);
uint index = hash & mask;
if (typeof(TKey) == typeof(ArenaUtf8String))
{
CheckAlive();
int maxBytes = Encoding.UTF8.GetMaxByteCount(key.Length);
byte[]? rented = null;
Span<byte> buffer = maxBytes <= 512 ? stackalloc byte[512] : (rented = ArrayPool<byte>.Shared.Rent(maxBytes));
try
{
int written = Encoding.UTF8.GetBytes(key, buffer);
ReadOnlySpan<byte> byteKey = buffer.Slice(0, written);

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf8String* keys = (ArenaUtf8String*)_header->Keys;
TValue* values = (TValue*)_header->Values;
uint mask = capacity - 1;
uint hash = Hashing.HashUtf8(byteKey);
uint index = hash & mask;
while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) break;
if (keys[entryIdxPlusOne - 1].Equals(byteKey))
{
value = values[entryIdxPlusOne - 1];
return true;
}
index = (index + 1) & mask;
}
}
finally
{
if (rented != null) ArrayPool<byte>.Shared.Return(rented);
}
}

while (true)
value = default;
return false;
}

/// <summary>
/// Specialized TryGetValue for ArenaUtf8String using ReadOnlySpan{byte} to avoid allocations.
/// </summary>
/// <param name="key">The key whose value to get.</param>
/// <param name="value">When this method returns, the value associated with the specified key, if the key is found; otherwise, the default value for the type of the <paramref name="value"/> parameter. This parameter is passed uninitialized.</param>
/// <returns><see langword="true"/> if the <see cref="ArenaDictionary{TKey, TValue}"/> contains an element with the specified key; otherwise, <see langword="false"/>.</returns>
public bool TryGetValue(ReadOnlySpan<byte> key, out TValue value)
{
if (typeof(TKey) == typeof(ArenaUtf8String))
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) break;
if (keys[entryIdxPlusOne - 1].Equals(key))
CheckAlive();
uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf8String* keys = (ArenaUtf8String*)_header->Keys;
TValue* values = (TValue*)_header->Values;
uint mask = capacity - 1;
uint hash = Hashing.HashUtf8(key);
uint index = hash & mask;
while (true)
{
value = values[entryIdxPlusOne - 1];
return true;
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) break;
if (keys[entryIdxPlusOne - 1].Equals(key))
{
value = values[entryIdxPlusOne - 1];
return true;
}
index = (index + 1) & mask;
}
}

if (typeof(TKey) == typeof(ArenaUtf16String))
{
CheckAlive();
int maxChars = Encoding.UTF8.GetMaxCharCount(key.Length);
char[]? rented = null;
Span<char> buffer = maxChars <= 512 ? stackalloc char[512] : (rented = ArrayPool<char>.Shared.Rent(maxChars));
try
{
int written = Encoding.UTF8.GetChars(key, buffer);
ReadOnlySpan<char> charKey = buffer.Slice(0, written);

uint capacity = (uint)_header->Capacity;
int* buckets = _header->Buckets;
ArenaUtf16String* keys = (ArenaUtf16String*)_header->Keys;
TValue* values = (TValue*)_header->Values;
uint mask = capacity - 1;
uint hash = Hashing.HashString(charKey);
uint index = hash & mask;

while (true)
{
int entryIdxPlusOne = buckets[index];
if (entryIdxPlusOne == 0) break;
if (keys[entryIdxPlusOne - 1].Equals(charKey))
{
value = values[entryIdxPlusOne - 1];
return true;
}
index = (index + 1) & mask;
}
}
finally
{
if (rented != null) ArrayPool<char>.Shared.Return(rented);
}
index = (index + 1) & mask;
}

value = default;
Expand Down
3 changes: 1 addition & 2 deletions src/SharpArena/Collections/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ public static uint Hash<T>(T value) where T : unmanaged
return (uint)Unsafe.As<T, ArenaUtf8String>(ref value).GetHashCode();
}

var span = new ReadOnlySpan<byte>(&value, sizeof(T));
return (uint)XxHash3.HashToUInt64(span);
return (uint)EqualityComparer<T>.Default.GetHashCode(value);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Loading
Loading