Commit 70c13294 authored by Jeroen F.J. Laros's avatar Jeroen F.J. Laros

Added counting.

parent c1c3d4ed
import itertools import itertools
def _add(root, word): def _add(root, word, count):
"""Add a word to the trie. """Add a word to the trie.
:arg dict root: Root of the trie. :arg dict root: Root of the trie.
:arg str word: A word. :arg str word: A word.
:arg int count: Multiplicity of `word`.
""" """
node = root node = root
...@@ -14,7 +15,9 @@ def _add(root, word): ...@@ -14,7 +15,9 @@ def _add(root, word):
node[char] = {} node[char] = {}
node = node[char] node = node[char]
node[''] = {} if '' not in node:
node[''] = {'count': 0}
node['']['count'] += count
def _find(root, word): def _find(root, word):
...@@ -35,25 +38,28 @@ def _find(root, word): ...@@ -35,25 +38,28 @@ def _find(root, word):
return node return node
def _remove(node, word): def _remove(node, word, count):
"""Remove a word from a trie. """Remove a word from a trie.
:arg dict node: Current node. :arg dict node: Current node.
:arg str word: Word to be removed. :arg str word: Word to be removed.
:arg int count: Multiplicity of `word`, force remove if this is -1.
:returns bool: :returns bool: True if the last occurrence of `word` is removed.
""" """
if not word: if not word:
if '' in node: if '' in node:
node.pop('') node['']['count'] -= count
return True if node['']['count'] < 1 or count == -1:
node.pop('')
return True
return False return False
car, cdr = word[0], word[1:] car, cdr = word[0], word[1:]
if car not in node: if car not in node:
return False return False
result = _remove(node[car], cdr) result = _remove(node[car], cdr, count)
if result: if result:
if not node[car]: if not node[car]:
node.pop(car) node.pop(car)
...@@ -194,11 +200,17 @@ class Trie(object): ...@@ -194,11 +200,17 @@ class Trie(object):
def __iter__(self): def __iter__(self):
return _iterate('', self.root) return _iterate('', self.root)
def add(self, word): def add(self, word, count=1):
_add(self.root, word) _add(self.root, word, count)
def remove(self, word): def get(self, word):
return _remove(self.root, word) node = _find(self.root, word)
if '' in node:
return node['']
return None
def remove(self, word, count=1):
return _remove(self.root, word, count)
def has_prefix(self, word): def has_prefix(self, word):
return _find(self.root, word) != {} return _find(self.root, word) != {}
......
...@@ -8,7 +8,7 @@ from dict_trie import Trie ...@@ -8,7 +8,7 @@ from dict_trie import Trie
class TestTrie(object): class TestTrie(object):
def setup(self): def setup(self):
self._trie = Trie(['abc', 'abd', 'test', 'te']) self._trie = Trie(['abc', 'abd', 'abd', 'test', 'te'])
def test_empty(self): def test_empty(self):
assert Trie().root == {} assert Trie().root == {}
...@@ -17,11 +17,11 @@ class TestTrie(object): ...@@ -17,11 +17,11 @@ class TestTrie(object):
assert self._trie.root == { assert self._trie.root == {
'a': { 'a': {
'b': { 'b': {
'c': {'': {}}, 'c': {'': {'count': 1}},
'd': {'': {}}}}, 'd': {'': {'count': 2}}}},
't': {'e': { 't': {'e': {
'': {}, '': {'count': 1},
's': {'t': {'': {}}}}}} 's': {'t': {'': {'count': 1}}}}}}
def test_word_present(self): def test_word_present(self):
assert 'abc' in self._trie assert 'abc' in self._trie
...@@ -60,6 +60,20 @@ class TestTrie(object): ...@@ -60,6 +60,20 @@ class TestTrie(object):
self._trie.add('abx') self._trie.add('abx')
assert 'abx' in self._trie assert 'abx' in self._trie
def test_get_present(self):
assert self._trie.get('abc')['count'] == 1
def test_get_absent(self):
assert not self._trie.get('abx')
def test_add_twice(self):
self._trie.add('abc')
assert self._trie.get('abc')['count'] == 2
def test_add_multiple(self):
self._trie.add('abc', 2)
assert self._trie.get('abc')['count'] == 3
def test_remove_present(self): def test_remove_present(self):
assert self._trie.remove('test') assert self._trie.remove('test')
assert 'test' not in self._trie assert 'test' not in self._trie
...@@ -76,6 +90,23 @@ class TestTrie(object): ...@@ -76,6 +90,23 @@ class TestTrie(object):
def test_remove_prefix_absent(self): def test_remove_prefix_absent(self):
assert not self._trie.remove('ab') assert not self._trie.remove('ab')
def test_remove_twice(self):
self._trie.add('abc')
assert not self._trie.remove('abc')
assert self._trie.get('abc')['count'] == 1
assert self._trie.remove('abc')
assert 'abc' not in self._trie
def test_remove_multile(self):
self._trie.add('abc', 3)
assert not self._trie.remove('abc', 2)
assert self._trie.get('abc')['count'] == 2
def test_remove_force(self):
self._trie.add('abc')
assert self._trie.remove('abc', -1)
assert 'abc' not in self._trie
def test_iter(self): def test_iter(self):
assert list(self._trie) == ['abc', 'abd', 'te', 'test'] assert list(self._trie) == ['abc', 'abd', 'te', 'test']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment