# lrucache.py
  1. from UserDict import DictMixin
  2. from heapq import heappush, heapify, heapreplace, heappop
  3. import unittest
  4. class LRUCache(DictMixin):
  5. """Heap queue based Least Recently Used Cache implementation
  6. """
  7. class Node(object):
  8. """Internal cache node
  9. """
  10. __slots__ = ('key', 'value', 't')
  11. def __init__(self, key, value, t):
  12. self.key = key
  13. self.value = value
  14. self.t = t
  15. def __cmp__(self, other):
  16. return cmp(self.t, other.t)
  17. def __init__(self, size):
  18. self._heap = []
  19. self._dict = {}
  20. self.size = size
  21. self._t = 0
  22. def __setitem__(self, key, value):
  23. self._t += 1
  24. try:
  25. node = self._dict[key]
  26. node.value = value
  27. node.t = self._t
  28. heapify(self._heap)
  29. except KeyError:
  30. node = self.Node(key, value, self._t)
  31. self._dict[key] = node
  32. if len(self) < self.size:
  33. heappush(self._heap, node)
  34. else:
  35. old = heapreplace(self._heap, node)
  36. del self._dict[old.key]
  37. def __getitem__(self, key):
  38. node = self._dict[key]
  39. self[key] = node.value
  40. return node.value
  41. def __delitem__(self, key):
  42. node = self._dict[key]
  43. del self._dict[key]
  44. self._heap.remove(node)
  45. heapify(self._heap)
  46. def __iter__(self):
  47. copy = self._heap[:]
  48. while copy:
  49. yield heappop(copy).key
  50. def iteritems(self):
  51. copy = self._heap[:]
  52. while copy:
  53. node = heappop(copy)
  54. yield node.key, node.value
  55. def keys(self):
  56. return self._dict.keys()
  57. def __contains__(self, key):
  58. return key in self._dict
  59. def __len__(self):
  60. return len(self._heap)
  61. class LRUCacheTestCase(unittest.TestCase):
  62. def test(self):
  63. c = LRUCache(2)
  64. self.assertEqual(len(c), 0)
  65. for i, x in enumerate('abc'):
  66. c[x] = i
  67. self.assertEqual(len(c), 2)
  68. self.assertEqual(list(c), ['b', 'c'])
  69. self.assertEqual(list(c.iteritems()), [('b', 1), ('c', 2)])
  70. self.assertEqual(False, 'a' in c)
  71. self.assertEqual(True, 'b' in c)
  72. self.assertRaises(KeyError, lambda: c['a'])
  73. self.assertEqual(c['b'], 1)
  74. self.assertEqual(c['c'], 2)
  75. c['d'] = 3
  76. self.assertEqual(len(c), 2)
  77. self.assertEqual(c['c'], 2)
  78. self.assertEqual(c['d'], 3)
  79. c['c'] = 22
  80. c['e'] = 4
  81. self.assertEqual(len(c), 2)
  82. self.assertRaises(KeyError, lambda: c['d'])
  83. self.assertEqual(c['c'], 22)
  84. self.assertEqual(c['e'], 4)
  85. del c['c']
  86. self.assertEqual(len(c), 1)
  87. self.assertRaises(KeyError, lambda: c['c'])
  88. self.assertEqual(c['e'], 4)
  89. def suite():
  90. return unittest.TestLoader().loadTestsFromTestCase(LRUCacheTestCase)
if __name__ == '__main__':
    # Executed as a script: run the tests defined in this module.
    unittest.main()