Commit 35176088 authored by Jeroen F.J. Laros's avatar Jeroen F.J. Laros

Added counting in iterators.

parent 70c13294
......@@ -67,20 +67,24 @@ def _remove(node, word, count):
return result
def _iterate(path, node):
def _iterate(path, node, unique):
"""Convert a trie into a list.
:arg str path: Path taken so far to reach the current node.
:arg dict node: Current node.
:arg bool unique: Do not list multiplicities.
:returns iter: All words in the trie.
"""
if '' in node:
if not unique:
for _ in range(1, node['']['count']):
yield path
yield path
for char in node:
if char:
for result in _iterate(path + char, node[char]):
for result in _iterate(path + char, node[char], unique):
yield result
......@@ -95,7 +99,7 @@ def _fill(node, alphabet, length):
{alphabet}.
"""
if not length:
node[''] = {}
node[''] = {'count': 1}
return
for char in alphabet:
......@@ -198,7 +202,10 @@ class Trie(object):
return '' in _find(self.root, word)
def __iter__(self):
return _iterate('', self.root)
return _iterate('', self.root, True)
def list(self, unique=True):
return _iterate('', self.root, unique)
def add(self, word, count=1):
_add(self.root, word, count)
......
......@@ -110,6 +110,13 @@ class TestTrie(object):
def test_iter(self):
assert list(self._trie) == ['abc', 'abd', 'te', 'test']
def test_list(self):
assert list(self._trie.list()) == list(self._trie)
def test_list_non_unique(self):
assert list(self._trie.list(False)) == [
'abc', 'abd', 'abd', 'te', 'test']
def test_fill(self):
trie = Trie()
trie.fill(('a', 'b'), 3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment