import unittest import pickle from tinygrad.helpers import diskcache_get, diskcache_put def remote_get(table,q,k): q.put(diskcache_get(table, k)) def remote_put(table,k,v): diskcache_put(table, k, v) class DiskCache(unittest.TestCase): def test_putget(self): table = "test_putget" diskcache_put(table, "hello", "world") self.assertEqual(diskcache_get(table, "hello"), "world") diskcache_put(table, "hello", "world2") self.assertEqual(diskcache_get(table, "hello"), "world2") def test_putcomplex(self): table = "test_putcomplex" diskcache_put(table, "k", ("complex", 123, "object")) ret = diskcache_get(table, "k") self.assertEqual(ret, ("complex", 123, "object")) def test_getotherprocess(self): table = "test_getotherprocess" from multiprocessing import Process, Queue diskcache_put(table, "k", "getme") q = Queue() p = Process(target=remote_get, args=(table,q,"k")) p.start() p.join() self.assertEqual(q.get(), "getme") def test_putotherprocess(self): table = "test_putotherprocess" from multiprocessing import Process p = Process(target=remote_put, args=(table,"k", "remote")) p.start() p.join() self.assertEqual(diskcache_get(table, "k"), "remote") def test_no_table(self): self.assertIsNone(diskcache_get("faketable", "k")) def test_ret(self): table = "test_ret" self.assertEqual(diskcache_put(table, "key", ("vvs",)), ("vvs",)) def test_non_str_key(self): table = "test_non_str_key" diskcache_put(table, 4, 5) self.assertEqual(diskcache_get(table, 4), 5) self.assertEqual(diskcache_get(table, "4"), 5) def test_dict_key(self): table = "test_dict_key" fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")} fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")} fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")} diskcache_put(table, fancy_key, 5) self.assertEqual(diskcache_get(table, fancy_key), 5) diskcache_put(table, fancy_key2, 8) self.assertEqual(diskcache_get(table, fancy_key2), 8) self.assertEqual(diskcache_get(table, fancy_key), 5) self.assertEqual(diskcache_get(table, fancy_key3), None) if __name__ == "__main__": unittest.main()