carrot/tinygrad_repo/test/extra/test_extra_helpers.py

57 lines
2.0 KiB
Python
Raw Normal View History

#!/usr/bin/env python
import os, cloudpickle, tempfile, unittest, subprocess
from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper
def normalize_line_endings(s):
    """Return ``s`` with every CRLF byte pair replaced by a single LF.

    The tests below apply this to captured process output before comparing
    it against an LF-terminated expected value. Lone ``\\r`` bytes are left
    as they are.
    """
    crlf = b'\r\n'
    lf = b'\n'
    return s.replace(crlf, lf)
class TestEarlyExec(unittest.TestCase):
    """Tests for the callable returned by ``enable_early_exec``.

    The callable is invoked with a single ``(argv_list, second_item)`` tuple;
    these tests pass ``None`` as the second item for normal runs.
    NOTE(review): the second item is presumably input for the child process;
    confirm against ``extra.helpers``.
    """

    def setUp(self) -> None:
        # Each test gets its own early-exec callable.
        self.early_exec = enable_early_exec()

    def early_exec_py_file(self, file_content, exec_args):
        """Run ``file_content`` as a temporary ``.py`` script and return the result.

        ``file_content`` is written as-is to a temporary file, which is then
        executed with ``python3`` followed by ``exec_args`` (a list). The
        temporary file is removed once the call finishes, whether it
        succeeded or raised.
        """
        # delete=False: the file has to outlive the ``with`` block so that the
        # child interpreter can open it by name; it is removed manually below.
        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as script:
            script.write(file_content)
            script_path = script.name
        try:
            command = ["python3", script_path] + exec_args
            return self.early_exec((command, None))
        finally:
            os.remove(script_path)

    def test_enable_early_exec(self):
        """A script that prints one line yields exactly that line."""
        script = b'print("Hello, world!")'
        output = self.early_exec_py_file(script, [])
        self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))

    def test_enable_early_exec_with_arg(self):
        """Extra command-line arguments reach the script through ``sys.argv``."""
        script = b'import sys\nprint("Hello, " + sys.argv[1] + "!")'
        output = self.early_exec_py_file(script, ["world"])
        self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))

    def test_enable_early_exec_process_exception(self):
        """A script that raises surfaces as ``subprocess.CalledProcessError``."""
        failing_script = b'raise Exception("Test exception")'
        with self.assertRaises(subprocess.CalledProcessError):
            self.early_exec_py_file(failing_script, [])

    def test_enable_early_exec_type_exception(self):
        """Passing a ``str`` as the second tuple item raises ``TypeError``."""
        bad_request = (["python3"], "print('Hello, world!')")
        with self.assertRaises(TypeError):
            self.early_exec(bad_request)
class TestCrossProcess(unittest.TestCase):
    """Tests for ``cross_process`` and ``_CloudpickleFunctionWrapper``."""

    def test_cross_process(self):
        """Every value produced by the generator function comes back, in order."""
        def _iterate():
            yield from range(10)

        results = list(cross_process(_iterate))
        expected = list(range(10))
        self.assertEqual(expected, results)

    def test_cross_process_exception(self):
        """An exception raised inside the generator propagates to the consumer."""
        def _iterate():
            # Produce 0..4, then fail where the sixth value would have been.
            yield from range(5)
            raise ValueError("Test exception")

        with self.assertRaises(ValueError):
            list(cross_process(_iterate))

    def test_CloudpickleFunctionWrapper(self):
        """A wrapped local function survives a cloudpickle round trip and still works."""
        def add(x, y):
            return x + y

        wrapped = _CloudpickleFunctionWrapper(add)
        serialized = cloudpickle.dumps(wrapped)
        restored = cloudpickle.loads(serialized)
        self.assertEqual(7, restored(3, 4))
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()