import sys
if sys.version_info.major < 3:
from itertools import imap as map
def _add(root, word, count):
"""Add a word to a trie.
:arg dict root: Root of the trie.
:arg str word: A word.
:arg int count: Multiplicity of `word`.
"""
node = root
for char in word:
if char not in node:
node[char] = {}
node = node[char]
if '' not in node:
node[''] = 0
node[''] += count
def _find(root, word):
"""Find the node after following the path in a trie given by {word}.
:arg dict root: Root of the trie.
:arg str word: A word.
:returns dict: The node if found, {} otherwise.
"""
node = root
for char in word:
if char not in node:
return {}
node = node[char]
return node
def _remove(node, word, count):
"""Remove a word from a trie.
:arg dict node: Current node.
:arg str word: Word to be removed.
:arg int count: Multiplicity of `word`, force remove if this is -1.
:returns bool: True if the last occurrence of `word` is removed.
"""
if not word:
if '' in node:
node[''] -= count
if node[''] < 1 or count == -1:
node.pop('')
return True
return False
car, cdr = word[0], word[1:]
if car not in node:
return False
result = _remove(node[car], cdr, count)
if result:
if not node[car]:
node.pop(car)
return result
def _iterate(path, node, unique):
"""Convert a trie into a list.
:arg str path: Path taken so far to reach the current node.
:arg dict node: Current node.
:arg bool unique: Do not list multiplicities.
:returns iter: All words in a trie.
"""
if '' in node:
if not unique:
for _ in range(1, node['']):
yield path
yield path
for char in node:
if char:
for result in _iterate(path + char, node[char], unique):
yield result
def _fill(node, alphabet, length):
"""Make a full trie using the characters in {alphabet}.
:arg dict node: Current node.
:arg tuple alphabet: Used alphabet.
:arg int length: Length of the words to be generated.
:returns iter: Trie containing all words of length {length} over alphabet
{alphabet}.
"""
if not length:
node[''] = 1
return
for char in alphabet:
node[char] = {}
_fill(node[char], alphabet, length - 1)
def _hamming(path, node, word, distance, cigar):
"""Find all paths in a trie that are within a certain hamming distance of
{word}.
:arg str path: Path taken so far to reach the current node.
:arg dict node: Current node.
:arg str word: Query word.
:arg int distance: Amount of allowed errors.
:returns iter: All word in a trie that have Hamming distance of at most
{distance} to {word}.
"""
if distance < 0:
return
if not word:
if '' in node:
yield (path, distance, cigar)
return
car, cdr = word[0], word[1:]
for char in node:
if char:
if char == car:
penalty = 0
operation = '='
else:
penalty = 1
operation = 'X'
for result in _hamming(
path + char, node[char], cdr, distance - penalty,
cigar + operation):
yield result
def _levenshtein(path, node, word, distance, cigar):
"""Find all paths in a trie that are within a certain Levenshtein
distance of {word}.
:arg str path: Path taken so far to reach the current node.
:arg dict node: Current node.
:arg str word: Query word.
:arg int distance: Amount of allowed errors.
:returns iter: All word in a trie that have Hamming distance of at most
{distance} to {word}.
"""
if distance < 0:
return
if not word:
if '' in node:
yield (path, distance, cigar)
car, cdr = '', ''
else:
car, cdr = word[0], word[1:]
# Deletion.
for result in _levenshtein(path, node, cdr, distance - 1, cigar + 'D'):
yield result
for char in node:
if char:
# Substitution.
if car:
if char == car:
penalty = 0
operation = '='
else:
penalty = 1
operation = 'X'
for result in _levenshtein(
path + char, node[char], cdr, distance - penalty,
cigar + operation):
yield result
# Insertion.
for result in _levenshtein(
path + char, node[char], word, distance - 1, cigar + 'I'):
yield result
class Trie(object):
    """Multiset of words stored in a trie of nested dictionaries.

    Every node is a dict that maps a character to a child node. The
    empty-string key of a node holds the multiplicity of the word that ends
    at that node.
    """

    def __init__(self, words=None):
        """Initialise the class.

        :arg list words: List of words.
        """
        self.root = {}
        if words:
            for word in words:
                self.add(word)

    def __contains__(self, word):
        """Check whether {word} is stored as a complete word."""
        return '' in _find(self.root, word)

    def __iter__(self):
        """Iterate over all words, each listed once."""
        return _iterate('', self.root, True)

    def list(self, unique=True):
        """List all words in the trie.

        :arg bool unique: Do not list multiplicities.

        :returns iter: All words in the trie.
        """
        return _iterate('', self.root, unique)

    def add(self, word, count=1):
        """Add a word to the trie.

        :arg str word: A word.
        :arg int count: Multiplicity of `word`.
        """
        _add(self.root, word, count)

    def get(self, word):
        """Get the multiplicity of a word.

        :arg str word: A word.

        :returns int: Multiplicity of `word`, or None if it is not stored.
        """
        return _find(self.root, word).get('')

    def remove(self, word, count=1):
        """Remove a word from the trie.

        :arg str word: Word to be removed.
        :arg int count: Multiplicity of `word`, force remove if this is -1.

        :returns bool: True if the last occurrence of `word` is removed.
        """
        return _remove(self.root, word, count)

    def has_prefix(self, word):
        """Check whether the path spelled by {word} exists in the trie.

        :arg str word: Prefix to look for.

        :returns bool: True if the node reached by following {word} is not
            empty.
        """
        return bool(_find(self.root, word))

    def fill(self, alphabet, length):
        """Add every word of length {length} over {alphabet} to the trie.

        :arg tuple alphabet: Used alphabet.
        :arg int length: Length of the words to be generated.
        """
        _fill(self.root, alphabet, length)

    def all_hamming_(self, word, distance):
        """Find all matches within a Hamming distance, with details.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns iter: Tuples `(match, errors used, cigar)`.
        """
        return (
            (match, distance - remaining, cigar)
            for match, remaining, cigar
            in _hamming('', self.root, word, distance, ''))

    def all_hamming(self, word, distance):
        """Find all words within a Hamming distance of {word}.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns iter: Matching words.
        """
        return (
            hit[0] for hit in _hamming('', self.root, word, distance, ''))

    def hamming(self, word, distance):
        """Find the first word within a Hamming distance of {word}.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns str: First match found, or '' if there is none.
        """
        try:
            return next(self.all_hamming(word, distance))
        except StopIteration:
            return ''

    def best_hamming(self, word, distance):
        """Find the best match with {word} in a trie.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns str: Best match with {word}, or '' if there is none.
        """
        # Only a stored word is an exact match; an existing path that is
        # merely a prefix of longer words must not be returned.
        if word in self:
            return word

        # Widen the search one error at a time so that the first hit has
        # the smallest possible distance.
        for i in range(1, distance + 1):
            result = self.hamming(word, i)
            if result:
                return result

        return ''

    def all_levenshtein_(self, word, distance):
        """Find all matches within a Levenshtein distance, with details.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns iter: Tuples `(match, errors used, cigar)`.
        """
        return (
            (match, distance - remaining, cigar)
            for match, remaining, cigar
            in _levenshtein('', self.root, word, distance, ''))

    def all_levenshtein(self, word, distance):
        """Find all words within a Levenshtein distance of {word}.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns iter: Matching words.
        """
        return (
            hit[0]
            for hit in _levenshtein('', self.root, word, distance, ''))

    def levenshtein(self, word, distance):
        """Find the first word within a Levenshtein distance of {word}.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns str: First match found, or '' if there is none.
        """
        try:
            return next(self.all_levenshtein(word, distance))
        except StopIteration:
            return ''

    def best_levenshtein(self, word, distance):
        """Find the best match with {word} in a trie.

        :arg str word: Query word.
        :arg int distance: Maximum allowed distance.

        :returns str: Best match with {word}, or '' if there is none.
        """
        # Only a stored word is an exact match; an existing path that is
        # merely a prefix of longer words must not be returned.
        if word in self:
            return word

        # Widen the search one error at a time so that the first hit has
        # the smallest possible distance.
        for i in range(1, distance + 1):
            result = self.levenshtein(word, i)
            if result:
                return result

        return ''