HashTree<TKey, TValue>
using System.Collections.Generic;
using System.Runtime.CompilerServices;
namespace Stashbox.Utils
{
/// <summary>
/// AVL-balanced binary search tree that maps keys to values, ordered by the keys' hash codes.
/// </summary>
/// <remarks>
/// Every operation takes a <c>byRef</c> flag. When <c>true</c>, keys are hashed with
/// <see cref="RuntimeHelpers.GetHashCode(object)"/> and compared by reference; when <c>false</c>,
/// the key's own <c>GetHashCode</c> and <c>object.Equals(object, object)</c> are used. The two modes
/// generally produce different hashes, so a key must be looked up with the same flag it was added with.
/// A key whose hash equals an existing node's hash but whose key differs is stored in that node's
/// collision list.
/// NOTE(review): with <c>byRef = true</c> a value-type <typeparamref name="TKey"/> is boxed on every
/// hash/compare, so it never matches itself; presumably byRef is only used with reference-type keys — confirm.
/// </remarks>
internal sealed class HashTree<TKey, TValue>
{
    private class Node<TK, T>
    {
        public readonly int storedHash;
        public readonly TK storedKey;
        public T storedValue;
        public Node<TK, T> left;
        public Node<TK, T> right;

        // Height of the subtree rooted at this node; a leaf has height 1.
        public int height;

        // Entries whose hash equals storedHash but whose key differs from storedKey; created on first collision.
        public ExpandableArray<TK, T> collisions;

        public Node(TK key, T value, int hash)
        {
            storedValue = value;
            storedKey = key;
            storedHash = hash;
            height = 1;
        }
    }

    private Node<TKey, TValue> root;

    /// <summary>Creates an empty tree.</summary>
    public HashTree()
    {
    }

    /// <summary>Creates a tree holding a single entry, hashed and compared by value (<c>byRef = false</c>).</summary>
    public HashTree(TKey key, TValue value)
    {
        Add(key, value, false);
    }

    /// <summary>
    /// Adds <paramref name="key"/> with <paramref name="value"/>. If <paramref name="key"/> is already the
    /// primary key of the node with its hash, that node's value is replaced; otherwise the entry becomes a
    /// new node, or is appended to the collision list of the node sharing its hash.
    /// </summary>
    /// <param name="key">Key to add. A null key hashes to 0.</param>
    /// <param name="value">Value to associate with the key.</param>
    /// <param name="byRef">Hash/compare by reference identity (<c>true</c>) or by value (<c>false</c>).</param>
    public void Add(TKey key, TValue value, bool byRef = true)
    {
        root = Add(root, key, HashOf(key, byRef), value, byRef);
    }

    /// <summary>
    /// Returns the value stored for <paramref name="key"/>, or <c>default(TValue)</c> when it is absent.
    /// </summary>
    /// <param name="key">Key to look up. A null key hashes to 0.</param>
    /// <param name="byRef">Must match the flag the key was added with.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public TValue GetOrDefault(TKey key, bool byRef = true)
    {
        Node<TKey, TValue> node = root;
        if (node == null)
            return default(TValue);

        int hash = HashOf(key, byRef);
        while (node != null && node.storedHash != hash)
            node = hash < node.storedHash ? node.left : node.right;

        if (node == null)
            return default(TValue);

        if (KeysEqual(key, node.storedKey, byRef))
            return node.storedValue;

        // Same hash, different key: the entry (if any) lives in the node's collision list.
        return node.collisions != null
            ? node.collisions.GetOrDefault(key, byRef)
            : default(TValue);
    }

    // Hash used to position a key in the tree. A null key hashes to 0 in both modes
    // (RuntimeHelpers.GetHashCode(null) is 0), instead of throwing in by-value mode.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static int HashOf(TKey key, bool byRef) =>
        byRef ? RuntimeHelpers.GetHashCode(key) : (key == null ? 0 : key.GetHashCode());

    // Key identity consistent with HashOf: reference equality in byRef mode, object.Equals otherwise.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool KeysEqual(TKey a, TKey b, bool byRef) =>
        byRef ? object.ReferenceEquals(a, b) : object.Equals(a, b);

    // Height of node computed from its children's (already up-to-date) heights.
    private static int CalculateHeight(Node<TKey, TValue> node)
    {
        int leftHeight = node.left?.height ?? 0;
        int rightHeight = node.right?.height ?? 0;
        return 1 + (leftHeight > rightHeight ? leftHeight : rightHeight);
    }

    // Balance factor: left subtree height minus right subtree height (missing child counts as 0).
    private static int GetBalance(Node<TKey, TValue> node) =>
        (node.left?.height ?? 0) - (node.right?.height ?? 0);

    // Rotates node left; returns the new subtree root (node's former right child).
    private static Node<TKey, TValue> RotateLeft(Node<TKey, TValue> node)
    {
        Node<TKey, TValue> newRoot = node.right;
        node.right = newRoot.left;
        newRoot.left = node;
        // node is now newRoot's child: refresh it first so newRoot's height is computed from the correct value.
        node.height = CalculateHeight(node);
        newRoot.height = CalculateHeight(newRoot);
        return newRoot;
    }

    // Rotates node right; returns the new subtree root (node's former left child).
    private static Node<TKey, TValue> RotateRight(Node<TKey, TValue> node)
    {
        Node<TKey, TValue> newRoot = node.left;
        node.left = newRoot.right;
        newRoot.right = node;
        // node is now newRoot's child: refresh it first so newRoot's height is computed from the correct value.
        node.height = CalculateHeight(node);
        newRoot.height = CalculateHeight(newRoot);
        return newRoot;
    }

    // Inserts into the subtree rooted at node and returns the (possibly new) subtree root, rebalanced.
    private static Node<TKey, TValue> Add(Node<TKey, TValue> node, TKey key, int hash, TValue value, bool byRef)
    {
        if (node == null)
            return new Node<TKey, TValue>(key, value, hash);

        if (node.storedHash == hash) {
            // No structural change, so no height update or rebalancing is needed.
            CheckCollisions(node, key, value, byRef);
            return node;
        }

        if (node.storedHash > hash)
            node.left = Add(node.left, key, hash, value, byRef);
        else
            node.right = Add(node.right, key, hash, value, byRef);

        node.height = CalculateHeight(node);
        int balance = GetBalance(node);

        if (balance >= 2) {
            // Left-heavy; if the left child leans right this is the left-right case: rotate it left first.
            if (GetBalance(node.left) < 0)
                node.left = RotateLeft(node.left);
            return RotateRight(node);
        }

        if (balance <= -2) {
            // Right-heavy; if the right child leans left this is the right-left case: rotate it right first.
            if (GetBalance(node.right) > 0)
                node.right = RotateRight(node.right);
            return RotateLeft(node);
        }

        return node;
    }

    // Stores key/value on a node whose hash already equals the key's hash.
    private static void CheckCollisions(Node<TKey, TValue> node, TKey key, TValue value, bool byRef)
    {
        if (KeysEqual(key, node.storedKey, byRef)) {
            // Same key: overwrite in place. Do not also append to the collision list; GetOrDefault
            // returns storedValue first, so such an entry could never be read and would only accumulate.
            node.storedValue = value;
            return;
        }

        // NOTE(review): if key is already in the collision list, this appends another entry rather than
        // replacing it; which entry GetOrDefault then returns depends on ExpandableArray — confirm.
        if (node.collisions == null)
            node.collisions = new ExpandableArray<TKey, TValue>();
        node.collisions.Add(new KeyValuePair<TKey, TValue>(key, value));
    }
}
}