carrot/tinygrad_repo/test/external/external_test_mamba.py
Vehicle Researcher 4fca6dec8e openpilot v0.9.8 release
date: 2025-01-29T09:09:56
master commit: 227bb68e1891619b360b89809e6822d50d34228f
2025-01-29 09:09:58 +00:00

25 lines
890 B
Python

import unittest
from tinygrad.helpers import CI
from examples.mamba import Mamba, generate
from transformers import AutoTokenizer
# Prompt that every generation test below continues from.
PROMPT = 'Why is gravity '
# Tokenizer shared by all tests, loaded from the "EleutherAI/gpt-neox-20b" checkpoint.
# Every test in TestMamba is skipped when CI is set, so the tokenizer is not loaded
# there: from_pretrained may need to fetch files, and nothing would use the result.
TOKENIZER = None if CI else AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
@unittest.skipIf(CI, "model is slow for CI")
class TestMamba(unittest.TestCase):
  """End-to-end generation checks for the tinygrad Mamba example.

  Each test loads a pretrained checkpoint, generates 10 tokens after PROMPT
  with the shared TOKENIZER, and compares the returned text to a reference string.
  """

  def _check_generation(self, size, expected, n_tokens_to_gen=10):
    """Load checkpoint `size`, generate from PROMPT and assert the output equals `expected`.

    Args:
      size: checkpoint name passed to Mamba.from_pretrained (e.g. '130m').
      expected: exact text `generate` must return.
      n_tokens_to_gen: number of tokens to generate after the prompt.
    """
    model = Mamba.from_pretrained(size)
    try:
      tinyoutput = generate(model, TOKENIZER, PROMPT, n_tokens_to_gen=n_tokens_to_gen)
    finally:
      # Release the model before asserting. In the original code, a failing
      # assertion skipped `del model`, so the failure's traceback frame kept the
      # model referenced.
      del model
    self.assertEqual(expected, tinyoutput)

  def test_mamba_130M(self):
    OUT_130M = '''Why is gravity \nnot a good idea?\n\nA:'''
    self._check_generation('130m', OUT_130M)

  def test_mamba_370M(self):
    OUT_370M = '''Why is gravity \nso important?\nBecause it's the only'''
    self._check_generation('370m', OUT_370M)
if __name__ == '__main__':
  # When executed directly (not via a test collector), hand off to unittest's CLI runner.
  unittest.main()