import ast
import pathlib
import sys
import unittest

import numpy as np
from PIL import Image

from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from models.efficientnet import EfficientNet
from models.vit import ViT
from models.resnet import ResNet50

# Test fixtures (label map and sample images) live next to this file.
_FIXTURE_DIR = pathlib.Path(__file__).parent / 'efficientnet'


def _load_labels():
  """Return the class-index -> label mapping used to decode model predictions.

  The fixture file is a Python dict literal; it is parsed with
  ``ast.literal_eval`` (literals only) rather than ``eval``.
  """
  labels_filename = _FIXTURE_DIR / 'imagenet1000_clsidx_to_labels.txt'
  return ast.literal_eval(labels_filename.read_text())


_LABELS = _load_labels()


def preprocess(img, new=False):
  """Resize, center-crop and normalize a PIL image into a float32 batch of one.

  The image is resized so its shorter side is 224 px (aspect ratio kept) and
  the central 224x224 window is cropped out.

  Args:
    img: a PIL image. Non-RGB modes (grayscale, palette, RGBA, CMYK, ...) are
      converted to RGB first so the array always has exactly 3 channels.
    new: selects the normalization scheme.
      False (default): NCHW array of shape (1, 3, 224, 224), scaled to [0, 1]
        and normalized with the ImageNet per-channel mean/std below.
      True: NHWC array of shape (1, 224, 224, 3), mapped to roughly [-1, 1]
        via (x - 127) / 128.

  Returns:
    A float32 numpy array in the layout described above.
  """
  if img.mode != 'RGB':
    # A grayscale image would otherwise become a 2-D array and crash in
    # np.moveaxis; 4-channel modes would not broadcast in the `new` branch.
    img = img.convert('RGB')

  # resize so the shorter side becomes 224, preserving aspect ratio
  aspect_ratio = img.size[0] / img.size[1]
  img = img.resize((int(224*max(aspect_ratio, 1.0)), int(224*max(1.0/aspect_ratio, 1.0))))
  img = np.array(img)

  # center crop to 224x224
  y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
  img = img[y0:y0 + 224, x0:x0 + 224]

  if new:
    img = img.astype(np.float32)
    img -= [127.0, 127.0, 127.0]
    img /= [128.0, 128.0, 128.0]
    img = img[None]  # add batch dimension, keep HWC layout
  else:
    # HWC -> CHW, then add the batch dimension
    img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
    img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
    img /= 255.0
    img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
    img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
  return img


def _infer(model, img, bs=1):
  """Classify ``img`` with ``model`` and return the top-1 label string.

  Args:
    model: any of the models under test (EfficientNet, ViT, ResNet50); only
      its ``forward`` method is called, with a Tensor built from the
      ``preprocess(img)`` output.
    img: a PIL image; preprocessed with the default (``new=False``) scheme.
    bs: batch size. The preprocessed image is repeated ``bs`` times along the
      batch axis to exercise batched inference; only the prediction for the
      first batch element is returned.

  Returns:
    The entry of ``_LABELS`` for the argmax class of batch element 0.
  """
  Tensor.training = False  # class-level flag: affects every Tensor, not just this call
  img = preprocess(img)
  if bs > 1:
    img = img.repeat(bs, axis=0)
  out = model.forward(Tensor(img)).cpu()
  return _LABELS[int(np.argmax(out.numpy()[0]))]


def _load_image(name):
  """Open a fixture image and return a fully loaded in-memory copy.

  ``Image.open`` is lazy and keeps the file handle open until pixel data is
  first read. Copying inside a ``with`` block forces the read and closes the
  file deterministically instead of leaving it to first use or GC.
  """
  with Image.open(_FIXTURE_DIR / name) as img:
    return img.copy()


chicken_img = _load_image('Chicken.jpg')
car_img = _load_image('car.jpg')


class TestEfficientNet(unittest.TestCase):
  """Top-1 predictions of a pretrained EfficientNet on the fixture images.

  The variant is selected with the NUM environment variable (via getenv).
  """

  @classmethod
  def setUpClass(cls):
    cls.model = EfficientNet(number=getenv("NUM"))
    cls.model.load_from_pretrained()

  @classmethod
  def tearDownClass(cls):
    del cls.model

  def test_chicken(self):
    label = _infer(self.model, chicken_img)
    self.assertEqual(label, "hen")

  def test_chicken_bigbatch(self):
    label = _infer(self.model, chicken_img, 2)
    self.assertEqual(label, "hen")

  def test_car(self):
    label = _infer(self.model, car_img)
    self.assertEqual(label, "sports car, sport car")


class TestViT(unittest.TestCase):
  """Top-1 predictions of a pretrained ViT on the fixture images."""

  @classmethod
  def setUpClass(cls):
    cls.model = ViT()
    cls.model.load_from_pretrained()

  @classmethod
  def tearDownClass(cls):
    del cls.model

  def test_chicken(self):
    label = _infer(self.model, chicken_img)
    self.assertEqual(label, "cock")

  def test_car(self):
    label = _infer(self.model, car_img)
    self.assertEqual(label, "racer, race car, racing car")


class TestResNet(unittest.TestCase):
  """Top-1 predictions of a pretrained ResNet50 on the fixture images."""

  @classmethod
  def setUpClass(cls):
    cls.model = ResNet50()
    cls.model.load_from_pretrained()

  @classmethod
  def tearDownClass(cls):
    del cls.model

  def test_chicken(self):
    label = _infer(self.model, chicken_img)
    self.assertEqual(label, "hen")

  def test_car(self):
    label = _infer(self.model, car_img)
    self.assertEqual(label, "sports car, sport car")


if __name__ == '__main__':
  unittest.main()