Commit ec68c5ec authored by Jeroen F.J. Laros

Merge branch 'master' of git.lumc.nl:j.f.j.laros/dict-trie

parents f5b13da5 1757a245
...@@ -110,7 +110,7 @@ using the `fill` function. ...@@ -110,7 +110,7 @@ using the `fill` function.
The trie data structure can be accessed via the `root` member variable. The trie data structure can be accessed via the `root` member variable.
```python ```python
>>> trie.root >>> trie.root
{'a': {'a': {'': {}}, 'b': {'': {}}}, 'b': {'a': {'': {}}, 'b': {'': {}}}} {'a': {'a': {'': 1}, 'b': {'': 1}}, 'b': {'a': {'': 1}, 'b': {'': 1}}}
>>> trie.root.keys() >>> trie.root.keys()
['a', 'b'] ['a', 'b']
``` ```
......
...@@ -9,7 +9,7 @@ Licensed under the MIT license, see the LICENSE file. ...@@ -9,7 +9,7 @@ Licensed under the MIT license, see the LICENSE file.
from .dict_trie import Trie from .dict_trie import Trie
__version_info__ = ('0', '0', '2') __version_info__ = ('0', '0', '3')
__version__ = '.'.join(__version_info__) __version__ = '.'.join(__version_info__)
__author__ = 'LUMC, Jeroen F.J. Laros' __author__ = 'LUMC, Jeroen F.J. Laros'
......
import itertools import itertools
def _add(root, word): def _add(root, word, count):
"""Add a word to a trie. """Add a word to a trie.
:arg dict root: Root of the trie. :arg dict root: Root of the trie.
:arg str word: A word. :arg str word: A word.
:arg int count: Multiplicity of `word`.
""" """
node = root node = root
...@@ -14,7 +15,9 @@ def _add(root, word): ...@@ -14,7 +15,9 @@ def _add(root, word):
node[char] = {} node[char] = {}
node = node[char] node = node[char]
node[''] = {} if '' not in node:
node[''] = 0
node[''] += count
def _find(root, word): def _find(root, word):
...@@ -35,25 +38,28 @@ def _find(root, word): ...@@ -35,25 +38,28 @@ def _find(root, word):
return node return node
def _remove(node, word): def _remove(node, word, count):
"""Remove a word from a trie. """Remove a word from a trie.
:arg dict node: Current node. :arg dict node: Current node.
:arg str word: Word to be removed. :arg str word: Word to be removed.
:arg int count: Multiplicity of `word`, force remove if this is -1.
:returns bool: :returns bool: True if the last occurrence of `word` is removed.
""" """
if not word: if not word:
if '' in node: if '' in node:
node.pop('') node[''] -= count
return True if node[''] < 1 or count == -1:
node.pop('')
return True
return False return False
car, cdr = word[0], word[1:] car, cdr = word[0], word[1:]
if car not in node: if car not in node:
return False return False
result = _remove(node[car], cdr) result = _remove(node[car], cdr, count)
if result: if result:
if not node[car]: if not node[car]:
node.pop(car) node.pop(car)
...@@ -61,20 +67,25 @@ def _remove(node, word): ...@@ -61,20 +67,25 @@ def _remove(node, word):
return result return result
def _iterate(path, node): def _iterate(path, node, unique):
"""Convert a trie into a list. """Convert a trie into a list.
:arg str path: Path taken so far to reach the current node. :arg str path: Path taken so far to reach the current node.
:arg dict node: Current node. :arg dict node: Current node.
:arg bool unique: Do not list multiplicities.
:returns iter: All words in a trie. :returns iter: All words in a trie.
""" """
if '' in node: if '' in node:
if not unique:
for _ in range(1, node['']):
yield path
yield path yield path
for char in node: for char in node:
for result in _iterate(path + char, node[char]): if char:
yield result for result in _iterate(path + char, node[char], unique):
yield result
def _fill(node, alphabet, length): def _fill(node, alphabet, length):
...@@ -88,7 +99,7 @@ def _fill(node, alphabet, length): ...@@ -88,7 +99,7 @@ def _fill(node, alphabet, length):
{alphabet}. {alphabet}.
""" """
if not length: if not length:
node[''] = {} node[''] = 1
return return
for char in alphabet: for char in alphabet:
...@@ -117,16 +128,17 @@ def _hamming(path, node, word, distance, cigar): ...@@ -117,16 +128,17 @@ def _hamming(path, node, word, distance, cigar):
car, cdr = word[0], word[1:] car, cdr = word[0], word[1:]
for char in node: for char in node:
if char == car: if char:
penalty = 0 if char == car:
operation = '=' penalty = 0
else: operation = '='
penalty = 1 else:
operation = 'X' penalty = 1
for result in _hamming( operation = 'X'
path + char, node[char], cdr, distance - penalty, for result in _hamming(
cigar + operation): path + char, node[char], cdr, distance - penalty,
yield result cigar + operation):
yield result
def _levenshtein(path, node, word, distance, cigar): def _levenshtein(path, node, word, distance, cigar):
...@@ -155,22 +167,23 @@ def _levenshtein(path, node, word, distance, cigar): ...@@ -155,22 +167,23 @@ def _levenshtein(path, node, word, distance, cigar):
yield result yield result
for char in node: for char in node:
# Substitution. if char:
if car: # Substitution.
if char == car: if car:
penalty = 0 if char == car:
operation = '=' penalty = 0
else: operation = '='
penalty = 1 else:
operation = 'X' penalty = 1
operation = 'X'
for result in _levenshtein(
path + char, node[char], cdr, distance - penalty,
cigar + operation):
yield result
# Insertion.
for result in _levenshtein( for result in _levenshtein(
path + char, node[char], cdr, distance - penalty, path + char, node[char], word, distance - 1, cigar + 'I'):
cigar + operation):
yield result yield result
# Insertion.
for result in _levenshtein(
path + char, node[char], word, distance - 1, cigar + 'I'):
yield result
class Trie(object): class Trie(object):
...@@ -189,13 +202,22 @@ class Trie(object): ...@@ -189,13 +202,22 @@ class Trie(object):
return '' in _find(self.root, word) return '' in _find(self.root, word)
def __iter__(self): def __iter__(self):
return _iterate('', self.root) return _iterate('', self.root, True)
def list(self, unique=True):
return _iterate('', self.root, unique)
def add(self, word): def add(self, word, count=1):
_add(self.root, word) _add(self.root, word, count)
def get(self, word):
node = _find(self.root, word)
if '' in node:
return node['']
return None
def remove(self, word): def remove(self, word, count=1):
return _remove(self.root, word) return _remove(self.root, word, count)
def has_prefix(self, word): def has_prefix(self, word):
return _find(self.root, word) != {} return _find(self.root, word) != {}
......
...@@ -8,7 +8,7 @@ from dict_trie import Trie ...@@ -8,7 +8,7 @@ from dict_trie import Trie
class TestTrie(object): class TestTrie(object):
def setup(self): def setup(self):
self._trie = Trie(['abc', 'abd', 'test', 'te']) self._trie = Trie(['abc', 'abd', 'abd', 'test', 'te'])
def test_empty(self): def test_empty(self):
assert Trie().root == {} assert Trie().root == {}
...@@ -17,11 +17,11 @@ class TestTrie(object): ...@@ -17,11 +17,11 @@ class TestTrie(object):
assert self._trie.root == { assert self._trie.root == {
'a': { 'a': {
'b': { 'b': {
'c': {'': {}}, 'c': {'': 1},
'd': {'': {}}}}, 'd': {'': 2}}},
't': {'e': { 't': {'e': {
'': {}, '': 1,
's': {'t': {'': {}}}}}} 's': {'t': {'': 1}}}}}
def test_word_present(self): def test_word_present(self):
assert 'abc' in self._trie assert 'abc' in self._trie
...@@ -60,6 +60,20 @@ class TestTrie(object): ...@@ -60,6 +60,20 @@ class TestTrie(object):
self._trie.add('abx') self._trie.add('abx')
assert 'abx' in self._trie assert 'abx' in self._trie
def test_get_present(self):
assert self._trie.get('abc') == 1
def test_get_absent(self):
assert not self._trie.get('abx')
def test_add_twice(self):
self._trie.add('abc')
assert self._trie.get('abc') == 2
def test_add_multiple(self):
self._trie.add('abc', 2)
assert self._trie.get('abc') == 3
def test_remove_present(self): def test_remove_present(self):
assert self._trie.remove('test') assert self._trie.remove('test')
assert 'test' not in self._trie assert 'test' not in self._trie
...@@ -76,9 +90,33 @@ class TestTrie(object): ...@@ -76,9 +90,33 @@ class TestTrie(object):
def test_remove_prefix_absent(self): def test_remove_prefix_absent(self):
assert not self._trie.remove('ab') assert not self._trie.remove('ab')
def test_remove_twice(self):
self._trie.add('abc')
assert not self._trie.remove('abc')
assert self._trie.get('abc') == 1
assert self._trie.remove('abc')
assert 'abc' not in self._trie
def test_remove_multile(self):
self._trie.add('abc', 3)
assert not self._trie.remove('abc', 2)
assert self._trie.get('abc') == 2
def test_remove_force(self):
self._trie.add('abc')
assert self._trie.remove('abc', -1)
assert 'abc' not in self._trie
def test_iter(self): def test_iter(self):
assert list(self._trie) == ['abc', 'abd', 'te', 'test'] assert list(self._trie) == ['abc', 'abd', 'te', 'test']
def test_list(self):
assert list(self._trie.list()) == list(self._trie)
def test_list_non_unique(self):
assert list(self._trie.list(False)) == [
'abc', 'abd', 'abd', 'te', 'test']
def test_fill(self): def test_fill(self):
trie = Trie() trie = Trie()
trie.fill(('a', 'b'), 3) trie.fill(('a', 'b'), 3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment