diff --git a/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java b/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java index 5e7d69a2d2..d036313222 100644 --- a/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java +++ b/src/main/java/org/apache/commons/collections4/trie/AbstractBitwiseTrie.java @@ -169,9 +169,9 @@ final K castKey(final Object key) { } /** - * A utility method for calling {@link KeyAnalyzer#compare(Object, Object)} + * Tests, in a null-safe way, whether two keys are equal according to {@link KeyAnalyzer#compare(Object, Object)}. */ - final boolean compareKeys(final K key, final K other) { + final boolean keysAreEqual(final K key, final K other) { if (key == null) { return other == null; } diff --git a/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java b/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java index a8a367c2ca..472bbe17d4 100644 --- a/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java +++ b/src/main/java/org/apache/commons/collections4/trie/AbstractPatriciaTrie.java @@ -1331,24 +1331,6 @@ TrieEntry addEntry(final TrieEntry entry, final int lengthInBits) { * than or equal to the given key, or null if there is no such key. */ TrieEntry ceilingEntry(final K key) { - // Basically: - // Follow the steps of adding an entry, but instead... - // - // - If we ever encounter a situation where we found an equal - // key, we return it immediately. - // - // - If we hit an empty root, return the first iterable item. - // - // - If we have to add a new item, we temporarily add it, - // find the successor to it, then remove the added item. - // - // These steps ensure that the returned value is either the - // entry for the key itself, or the first entry directly after - // the key. - - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. 
(We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1359,19 +1341,31 @@ TrieEntry ceilingEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return found; } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry ceil = nextEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return ceil; + if (!isBitSet(key, bitIndex, lengthInBits)) { + // search key < found.key + // found is a ceiling candidate, walk backward to find the smallest entry still >= key + TrieEntry ceiling = found; + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) <= 0) { + ceiling = prev; + prev = previousEntry(prev); + } + return ceiling; + } else { + // search key > found.key + // walk forward to find the first entry.key > key + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) > 0) { + next = nextEntry(next); + } + return next; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1416,7 +1410,7 @@ public boolean containsKey(final Object k) { final K key = castKey(k); final int lengthInBits = lengthInBits(key); final TrieEntry entry = getNearestEntryForKey(key, lengthInBits); - return !entry.isEmpty() && compareKeys(key, entry.key); + return !entry.isEmpty() && keysAreEqual(key, entry.key); } /** @@ -1463,9 +1457,6 @@ public K firstKey() { * less than or equal to the given key, or null if there is no such key. 
*/ TrieEntry floorEntry(final K key) { - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1476,19 +1467,30 @@ TrieEntry floorEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return found; } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry floor = previousEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return floor; + if (isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry floor = found; + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) >= 0) { + floor = next; + next = nextEntry(next); + } + return floor; + } else { + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) < 0) { + prev = previousEntry(prev); + } + if (prev == null || prev.isEmpty()) { + return null; + } + return prev; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1561,7 +1563,7 @@ TrieEntry getEntry(final Object k) { final int lengthInBits = lengthInBits(key); final TrieEntry entry = getNearestEntryForKey(key, lengthInBits); - return !entry.isEmpty() && compareKeys(key, entry.key) ? entry : null; + return !entry.isEmpty() && keysAreEqual(key, entry.key) ? entry : null; } /** @@ -1632,9 +1634,6 @@ public SortedMap headMap(final K toKey) { * or null if no such entry exists. 
*/ TrieEntry higherEntry(final K key) { - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1651,19 +1650,27 @@ TrieEntry higherEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return nextEntry(found); } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry ceil = nextEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return ceil; + if (!isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry ceiling = found; + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) <= 0) { + ceiling = prev; + prev = previousEntry(prev); + } + return ceiling; + } else { + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) > 0) { + next = nextEntry(next); + } + return next; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { if (!root.isEmpty()) { @@ -1729,23 +1736,6 @@ public K lastKey() { * strictly less than the given key, or null if there is no such key. */ TrieEntry lowerEntry(final K key) { - // Basically: - // Follow the steps of adding an entry, but instead... - // - // - If we ever encounter a situation where we found an equal - // key, we return it's previousEntry immediately. - // - // - If we hit root (empty or not), return null. - // - // - If we have to add a new item, we temporarily add it, - // find the previousEntry to it, then remove the added item. 
- // - // These steps ensure that the returned value is always just before - // the key or null (if there was nothing before it). - - // TODO: Cleanup so that we don't actually have to add/remove from the - // tree. (We do it here because there are other well-defined - // functions to perform the search.) final int lengthInBits = lengthInBits(key); if (lengthInBits == 0) { @@ -1753,19 +1743,30 @@ TrieEntry lowerEntry(final K key) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { return previousEntry(found); } final int bitIndex = bitIndex(key, found.key); if (KeyAnalyzer.isValidBitIndex(bitIndex)) { - final TrieEntry added = new TrieEntry<>(key, null, bitIndex); - addEntry(added, lengthInBits); - incrementSize(); // must increment because remove will decrement - final TrieEntry prior = previousEntry(added); - removeEntry(added); - modCount -= 2; // we didn't really modify it. - return prior; + if (isBitSet(key, bitIndex, lengthInBits)) { + TrieEntry floor = found; + TrieEntry next = nextEntry(found); + while (next != null && getKeyAnalyzer().compare(key, next.key) >= 0) { + floor = next; + next = nextEntry(next); + } + return floor; + } else { + TrieEntry prev = previousEntry(found); + while (prev != null && !prev.isEmpty() && getKeyAnalyzer().compare(key, prev.key) < 0) { + prev = previousEntry(prev); + } + if (prev == null || prev.isEmpty()) { + return null; + } + return prev; + } } if (KeyAnalyzer.isNullBitKey(bitIndex)) { return null; @@ -2028,7 +2029,7 @@ public V put(final K key, final V value) { } final TrieEntry found = getNearestEntryForKey(key, lengthInBits); - if (compareKeys(key, found.key)) { + if (keysAreEqual(key, found.key)) { if (found.isEmpty()) { // <- must be the root incrementSize(); } else { @@ -2104,7 +2105,7 @@ public V remove(final Object k) { TrieEntry path = root; while (true) { if (current.bitIndex <= path.bitIndex) { - if (!current.isEmpty() 
&& compareKeys(key, current.key)) { + if (!current.isEmpty() && keysAreEqual(key, current.key)) { return removeEntry(current); } return null; diff --git a/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java b/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java index 9da685ef53..6a7c766792 100644 --- a/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java +++ b/src/test/java/org/apache/commons/collections4/trie/PatriciaTrieTest.java @@ -33,6 +33,14 @@ import java.util.NoSuchElementException; import java.util.Set; import java.util.SortedMap; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + import org.apache.commons.collections4.Trie; import org.apache.commons.collections4.map.AbstractSortedMapTest; @@ -437,6 +445,229 @@ void testPrefixMapSizes2() { assertTrue(trie.prefixMap(prefixString).containsKey(longerString)); } + @Test + void testSubmap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // subMap should be entire trie + SortedMap subMap = trie.subMap("a", "z"); + assertEquals(5, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + // subMap should be empty + subMap = trie.subMap("a", "a"); + assertEquals(0, subMap.size()); + + // subMap() is not inclusive of the second key + // subMap should be 4 entries only - "ge" excluded + subMap = trie.subMap("ga", "ge"); + assertEquals(4, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", 
subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertNull(subMap.get("ge")); + + // subMap should be 5 entries + subMap = trie.subMap("ga", "gf"); + assertEquals(5, subMap.size()); + assertEquals("ga", subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + // subMap should be 4 entries - "ga" excluded + subMap = trie.subMap("gb", "z"); + assertEquals(4, subMap.size()); + assertNull(subMap.get("ga")); + assertEquals("gb", subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertEquals("gd", subMap.get("gd")); + assertEquals("ge", subMap.get("ge")); + + + // subMap should be 1 entry - "gc" only + subMap = trie.subMap("gc", "gd"); + assertEquals(1, subMap.size()); + assertNull(subMap.get("ga")); + assertNull(subMap.get("gb")); + assertEquals("gc", subMap.get("gc")); + assertNull(subMap.get("gd")); + assertNull(subMap.get("ge")); + } + + @Test + void testTailMap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // tailMap should be entire trie + SortedMap tailMap = trie.tailMap("a"); + assertEquals(5, tailMap.size()); + assertEquals("ga", tailMap.get("ga")); + assertEquals("gb", tailMap.get("gb")); + assertEquals("gc", tailMap.get("gc")); + assertEquals("gd", tailMap.get("gd")); + assertEquals("ge", tailMap.get("ge")); + + // tailMap should be empty + tailMap = trie.tailMap("z"); + assertEquals(0, tailMap.size()); + + // tailMap is inclusive of the search key + // tailMap should be the entire trie + tailMap = trie.tailMap("ga"); + assertEquals(5, tailMap.size()); + assertEquals("ga", tailMap.get("ga")); + assertEquals("gb", tailMap.get("gb")); + assertEquals("gc", tailMap.get("gc")); + assertEquals("gd", tailMap.get("gd")); + assertEquals("ge", 
tailMap.get("ge")); + + // tailMap should be single entry "ge" + tailMap = trie.tailMap("ge"); + assertEquals(1, tailMap.size()); + assertNull(tailMap.get("ga")); + assertNull(tailMap.get("gb")); + assertNull(tailMap.get("gc")); + assertNull(tailMap.get("gd")); + assertEquals("ge", tailMap.get("ge")); + } + + @Test + void testHeadMap() { + final PatriciaTrie trie = new PatriciaTrie<>(); + trie.put("ga", "ga"); + trie.put("gb", "gb"); + trie.put("gc", "gc"); + trie.put("gd", "gd"); + trie.put("ge", "ge"); + + // headMap should be entire trie + SortedMap headMap = trie.headMap("z"); + assertEquals(5, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertEquals("ge", headMap.get("ge")); + + // headMap should be empty + headMap = trie.headMap("a"); + assertEquals(0, headMap.size()); + + // headMap() is not inclusive of the key + // headMap should be 4 entries only - "ge" excluded + headMap = trie.headMap("ge"); + assertEquals(4, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertNull(headMap.get("ge")); + + // headMap should be 5 entries + headMap = trie.headMap("gf"); + assertEquals(5, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertEquals("gb", headMap.get("gb")); + assertEquals("gc", headMap.get("gc")); + assertEquals("gd", headMap.get("gd")); + assertEquals("ge", headMap.get("ge")); + + // headMap should be 1 entry - "ga" only + headMap = trie.headMap("gb"); + assertEquals(1, headMap.size()); + assertEquals("ga", headMap.get("ga")); + assertNull(headMap.get("gb")); + assertNull(headMap.get("gc")); + assertNull(headMap.get("gd")); + assertNull(headMap.get("ge")); + } + + + @Test + void testConcurrentTrieIterationAndSubMapIteration() throws InterruptedException, ExecutionException, 
TimeoutException { + final PatriciaTrie trie = new PatriciaTrie<>(); + // populate with enough entries to make concurrent collisions likely + // call subMap with both keys missing to ensure phantom node addition done twice + final int subKeyFirst = 1; + final int subKeySecond = 4; + final String subKeyFirstStr = String.format("key%04d", subKeyFirst); + final String subKeySecondStr = String.format("key%04d", subKeySecond); + for (int i = 0; i <= 501; i++) { + if (i != subKeyFirst && i != subKeySecond) { + trie.put(String.format("key%04d", i), i); + } + } + + final int iterations = 100; + final CyclicBarrier barrier = new CyclicBarrier(2); + + final ExecutorService executor = Executors.newFixedThreadPool(2); + try { + // Thread 1: repeatedly iterate the entire trie + final Future iteratorTask = executor.submit(() -> { + barrier.await(1, TimeUnit.SECONDS); + for (int i = 0; i < iterations && !Thread.currentThread().isInterrupted(); i++) { + int count = 0; + for (final Map.Entry entry : trie.entrySet()) { + // verify the iterated keys and values are not from the phantom node + assertNotNull(entry.getKey()); + assertNotNull(entry.getValue()); + count++; + } + assertEquals(500, count, "Iterator skipped or duplicated entries"); + } + return null; + }); + + // Thread 2: repeatedly create and iterate subMap views + // (this triggers ceilingEntry with keys NOT in the trie) + final Future subMapTask = executor.submit(() -> { + barrier.await(1, TimeUnit.SECONDS); + for (int i = 0; i < iterations && !Thread.currentThread().isInterrupted(); i++) { + // Use boundary keys that do NOT exist in the trie + // to force the ceiling/floor walk algorithm + final SortedMap sub = trie.subMap(subKeyFirstStr, subKeySecondStr); + int count = 0; + for (final Map.Entry entry : sub.entrySet()) { + // verify the iterated keys and values are not from the phantom node + assertNotNull(entry.getKey()); + assertNotNull(entry.getValue()); + count++; + } + assertEquals(2, count, "subMap returned wrong 
number of entries"); + } + return null; + }); + + // Future.get() rethrows any failure raised inside a task (an exception or an assertion Error) + // wrapped in an ExecutionException whose cause is the original Throwable, + // so the original stack trace is preserved when this test fails, + // and a TimeoutException surfaces hangs + subMapTask.get(10, TimeUnit.SECONDS); + iteratorTask.get(10, TimeUnit.SECONDS); + } finally { + executor.shutdownNow(); + executor.awaitTermination(5, TimeUnit.SECONDS); + } + } + // void testCreate() throws Exception { // resetEmpty(); // writeExternalFormToDisk(