From 70c13294bc18b4d081d3431c957ee610c88d3116 Mon Sep 17 00:00:00 2001 From: "J.F.J. Laros" Date: Fri, 4 Aug 2017 17:20:00 +0200 Subject: [PATCH] Added counting. --- dict_trie/dict_trie.py | 34 +++++++++++++++++++++++----------- tests/test_trie.py | 41 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/dict_trie/dict_trie.py b/dict_trie/dict_trie.py index 9f3ba91..0c36404 100644 --- a/dict_trie/dict_trie.py +++ b/dict_trie/dict_trie.py @@ -1,11 +1,12 @@ import itertools -def _add(root, word): +def _add(root, word, count): """Add a word to the trie. :arg dict root: Root of the trie. :arg str word: A word. + :arg int count: Multiplicity of `word`. """ node = root @@ -14,7 +15,9 @@ def _add(root, word): node[char] = {} node = node[char] - node[''] = {} + if '' not in node: + node[''] = {'count': 0} + node['']['count'] += count def _find(root, word): @@ -35,25 +38,28 @@ def _find(root, word): return node -def _remove(node, word): +def _remove(node, word, count): """Remove a word from a trie. :arg dict node: Current node. :arg str word: Word to be removed. + :arg int count: Multiplicity of `word`, force remove if this is -1. - :returns bool: + :returns bool: True if the last occurrence of `word` is removed. """ if not word: if '' in node: - node.pop('') - return True + node['']['count'] -= count + if node['']['count'] < 1 or count == -1: + node.pop('') + return True return False car, cdr = word[0], word[1:] if car not in node: return False - result = _remove(node[car], cdr) + result = _remove(node[car], cdr, count) if result: if not node[car]: node.pop(car) @@ -194,11 +200,17 @@ class Trie(object): def __iter__(self): return _iterate('', self.root) - def add(self, word): - _add(self.root, word) + def add(self, word, count=1): + _add(self.root, word, count) - def remove(self, word): - return _remove(self.root, word) + def get(self, word): + node = _find(self.root, word) + if '' in node: + return node[''] + return None + + def remove(self, word, count=1): + return _remove(self.root, word, count) def has_prefix(self, word): return _find(self.root, word) != {} diff --git a/tests/test_trie.py b/tests/test_trie.py index eac0bb1..38911b8 100644 --- a/tests/test_trie.py +++ b/tests/test_trie.py @@ -8,7 +8,7 @@ from dict_trie import Trie class TestTrie(object): def setup(self): - self._trie = Trie(['abc', 'abd', 'test', 'te']) + self._trie = Trie(['abc', 'abd', 'abd', 'test', 'te']) def test_empty(self): assert Trie().root == {} @@ -17,11 +17,11 @@ class TestTrie(object): assert self._trie.root == { 'a': { 'b': { - 'c': {'': {}}, - 'd': {'': {}}}}, + 'c': {'': {'count': 1}}, + 'd': {'': {'count': 2}}}}, 't': {'e': { - '': {}, - 's': {'t': {'': {}}}}}} + '': {'count': 1}, + 's': {'t': {'': {'count': 1}}}}}} def test_word_present(self): assert 'abc' in self._trie @@ -60,6 +60,20 @@ class TestTrie(object): self._trie.add('abx') assert 'abx' in self._trie + def test_get_present(self): + assert self._trie.get('abc')['count'] == 1 + + def test_get_absent(self): + assert not self._trie.get('abx') + + def test_add_twice(self): + self._trie.add('abc') + assert self._trie.get('abc')['count'] == 2 + + def test_add_multiple(self): + self._trie.add('abc', 2) + assert self._trie.get('abc')['count'] == 3 + def test_remove_present(self): assert self._trie.remove('test') assert 'test' not in self._trie @@ -76,6 +90,23 @@ class TestTrie(object): def test_remove_prefix_absent(self): assert not self._trie.remove('ab') + def test_remove_twice(self): + self._trie.add('abc') + assert not self._trie.remove('abc') + assert self._trie.get('abc')['count'] == 1 + assert self._trie.remove('abc') + assert 'abc' not in self._trie + + def test_remove_multile(self): + self._trie.add('abc', 3) + assert not self._trie.remove('abc', 2) + assert self._trie.get('abc')['count'] == 2 + + def test_remove_force(self): + self._trie.add('abc') + assert self._trie.remove('abc', -1) + assert 'abc' not in self._trie + def test_iter(self): assert list(self._trie) == ['abc', 'abd', 'te', 'test'] -- GitLab