As an interesting exercise, I decided to write my own Trie implementation. The code is shown below. I’ve tested it with a small amount of data, and it seems to work OK. The next step is to load it up with a large amount of data to see how it performs against the Radix Tree implementation and a binary search. I will save that for a later post.
A data structure like this would work great in conjunction with the YUI Autocomplete control.
The Trie Interface
import java.util.List;
import java.util.Set;
/**
 * A string-keyed collection of values that, besides exact-key lookup, can answer
 * "which values are stored under keys starting with this prefix?".
 *
 * @param <T> the type of the values stored in the trie
 */
public interface Trie<T> {

    /**
     * Associates {@code value} with {@code key}.
     *
     * @param key the key to store
     * @param value the value to associate with the key
     */
    void add(String key, T value);

    /**
     * Looks up the value stored under exactly {@code key}.
     *
     * @param key the key to look up
     * @return the associated value (the {@code TrieImpl} implementation returns null when
     *     the key is not stored)
     */
    T find(String key);

    /**
     * Collects the values of every stored key that begins with {@code prefix}.
     *
     * @param prefix the prefix to match
     * @return the matching values
     */
    List<T> search(String prefix);

    /**
     * Tells whether {@code key} itself is stored (not merely a prefix of a stored key).
     *
     * @param key the key to test
     * @return true if the key is stored
     */
    boolean contains(String key);

    /**
     * Returns every stored key.
     *
     * @return the set of stored keys
     */
    Set<String> getAllKeys();

    /**
     * Returns how many keys are stored.
     *
     * @return the number of stored keys
     */
    int size();
}
The Trie Implementation
There are a lot of recursive calls going on here…
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class TrieImpl<T> implements Trie<T> {
private TrieNode<T> rootNode = new TrieNode<T>();
public void add(String key, T value) {
addNode(rootNode, key, 0, value);
}
public T find(String key) {
return findKey(rootNode, key);
}
public List<T> search(String prefix) {
List<T> list = new ArrayList<T>();
char[] ch = prefix.toCharArray();
TrieNode<T> node = rootNode;
for (int i = 0; i < ch.length; i++) {
node = node.getChildren().get(ch[i]);
if (node == null) {
break;
}
}
if (node != null) {
getValuesFromNode(node, list);
}
return list;
}
public boolean contains(String key) {
return hasKey(rootNode, key);
}
public Set<String> getAllKeys() {
Set<String> keySet = new HashSet<String>();
getKeysFromNode(rootNode, "", keySet);
return keySet;
}
public int size() {
return getAllKeys().size();
}
private void getValuesFromNode(TrieNode<T> currNode, List<T> valueList) {
if (currNode.isTerminal()) {
valueList.add(currNode.getNodeValue());
}
Map<Character, TrieNode<T>> children = currNode.getChildren();
Iterator childIter = children.keySet().iterator();
while (childIter.hasNext()) {
Character ch = (Character)childIter.next();
TrieNode<T> nextNode = children.get(ch);
getValuesFromNode(nextNode, valueList);
}
}
private void getKeysFromNode(TrieNode<T> currNode, String key, Set keySet) {
if (currNode.isTerminal()) {
keySet.add(key);
}
Map<Character, TrieNode<T>> children = currNode.getChildren();
Iterator childIter = children.keySet().iterator();
while (childIter.hasNext()) {
Character ch = (Character)childIter.next();
TrieNode<T> nextNode = children.get(ch);
String s = key + nextNode.getNodeKey();
getKeysFromNode(nextNode, s, keySet);
}
}
private T findKey(TrieNode<T> currNode, String key) {
Character c = key.charAt(0);
if (currNode.getChildren().containsKey(c)) {
TrieNode<T> nextNode = currNode.getChildren().get(c);
if (key.length() == 1) {
if (nextNode.isTerminal()) {
return nextNode.getNodeValue();
}
} else {
return findKey(nextNode, key.substring(1));
}
}
return null;
}
private boolean hasKey(TrieNode<T> currNode, String key) {
Character c = key.charAt(0);
if (currNode.getChildren().containsKey(c)) {
TrieNode<T> nextNode = currNode.getChildren().get(c);
if (key.length() == 1) {
if (nextNode.isTerminal()) {
return true;
}
} else {
return hasKey(nextNode, key.substring(1));
}
}
return false;
}
private void addNode(TrieNode<T> currNode, String key, int pos, T value) {
Character c = key.charAt(pos);
TrieNode<T> nextNode = currNode.getChildren().get(c);
if (nextNode == null) {
nextNode = new TrieNode<T>();
nextNode.setNodeKey(c);
if (pos < key.length() - 1) {
addNode(nextNode, key, pos + 1, value);
} else {
nextNode.setNodeValue(value);
nextNode.setTerminal(true);
}
currNode.getChildren().put(c, nextNode);
} else {
if (pos < key.length() - 1) {
addNode(nextNode, key, pos + 1, value);
} else {
nextNode.setNodeValue(value);
nextNode.setTerminal(true);
}
}
}
}
The Trie Node
import java.util.HashMap;
import java.util.Map;
/**
 * A single node in a character trie.
 *
 * <p>A node holds the character on the edge leading to it, a value, a terminal flag, and a
 * mutable map from the next character to the corresponding child node. A newly constructed
 * node has no edge character, no value, is not terminal, and owns its own empty
 * {@link HashMap} of children.
 *
 * @param <T> the type of value stored at a node
 */
public class TrieNode<T> {

    /** Outgoing edges: next character to child node. Returned directly by getChildren(). */
    private Map<Character, TrieNode<T>> childNodes = new HashMap<Character, TrieNode<T>>();

    /** Character on the incoming edge; stays null until setNodeKey is called. */
    private Character edgeChar;

    /** Value held by this node; null until setNodeValue is called. */
    private T payload;

    /** Terminal flag; TrieImpl sets it on the node where an added key ends. */
    private boolean endsKey;

    /**
     * Returns the character on the incoming edge.
     *
     * @return the edge character, or null if it was never set
     */
    public Character getNodeKey() {
        return edgeChar;
    }

    /**
     * Records the character on the incoming edge.
     *
     * @param nodeKey the edge character (null is accepted)
     */
    public void setNodeKey(Character nodeKey) {
        edgeChar = nodeKey;
    }

    /**
     * Returns the value held by this node.
     *
     * @return the stored value, or null if none was set
     */
    public T getNodeValue() {
        return payload;
    }

    /**
     * Replaces the value held by this node.
     *
     * @param nodeValue the new value (null is accepted)
     */
    public void setNodeValue(T nodeValue) {
        payload = nodeValue;
    }

    /**
     * Reports the terminal flag.
     *
     * @return true if the terminal flag is set
     */
    public boolean isTerminal() {
        return endsKey;
    }

    /**
     * Sets or clears the terminal flag.
     *
     * @param terminal the new flag value
     */
    public void setTerminal(boolean terminal) {
        endsKey = terminal;
    }

    /**
     * Returns the live child map (not a copy); changes made through it affect this node.
     *
     * @return the map from next character to child node
     */
    public Map<Character, TrieNode<T>> getChildren() {
        return childNodes;
    }

    /**
     * Replaces the child map with the given instance, which is stored as-is (not copied).
     *
     * @param children the new child map
     */
    public void setChildren(Map<Character, TrieNode<T>> children) {
        childNodes = children;
    }
}
The Test
import junit.framework.TestCase;
public class TrieTest extends TestCase {
Trie<Product> productTrie = new TrieImpl<Product>();
public void setUp() throws Exception {
productTrie.add("ham", new Product(1, "ham"));
productTrie.add("hammer", new Product(2, "hammer"));
productTrie.add("hammock", new Product(3, "hammock"));
productTrie.add("ipod", new Product(4, "ipod"));
productTrie.add("iphone", new Product(5, "iphone"));
}
public void testAdd() {
assertEquals(5, productTrie.size());
}
public void testFind() {
assertNotNull(productTrie.find("ham"));
}
public void testSearch() {
assertEquals(3, productTrie.search("ha").size());
}
public void testContains() {
assertEquals(true, productTrie.contains("ipod"));
}
public void testGetAllKeys() {
assertEquals(5, productTrie.getAllKeys().size());
}
class Product {
private int productId;
private String productDesc;
public Product(int productId, String productDesc) {
this.productId = productId;
this.productDesc = productDesc;
}
public int getProductId() {
return productId;
}
public void setProductId(int productId) {
this.productId = productId;
}
public String getProductDesc() {
return productDesc;
}
public void setProductDesc(String productDesc) {
this.productDesc = productDesc;
}
}
}


#1 by Ravi Teja Karri on January 27, 2012 - 6:59 am
thank you