carrot/tinygrad_repo/test/external/external_test_datasets.py
Vehicle Researcher 4fca6dec8e openpilot v0.9.8 release
date: 2025-01-29T09:09:56
master commit: 227bb68e1891619b360b89809e6822d50d34228f
2025-01-29 09:09:58 +00:00

79 lines
3.3 KiB
Python

from extra.datasets.kits19 import iterate, preprocess
from examples.mlperf.dataloader import batch_load_unet3d
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
from pathlib import Path
import nibabel as nib
import numpy as np
import os
import random
import tempfile
import unittest
class ExternalTestDatasets(unittest.TestCase):
def _set_seed(self):
np.random.seed(42)
random.seed(42)
def _create_samples(self, val, num_samples=2):
self._set_seed()
img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
dataset = "val" if val else "train"
preproc_pth = Path(tempfile.gettempdir() + f"/{dataset}")
for i in range(num_samples):
os.makedirs(tempfile.gettempdir() + f"/case_000{i}", exist_ok=True)
nib.save(img, temp(f"case_000{i}/imaging.nii.gz"))
nib.save(lbl, temp(f"case_000{i}/segmentation.nii.gz"))
preproc_img, preproc_lbl = preprocess(Path(tempfile.gettempdir()) / f"case_000{i}")
preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_000{i}_x.npy"), temp(f"{dataset}/case_000{i}_y.npy")
os.makedirs(preproc_pth, exist_ok=True)
np.save(preproc_img_pth, preproc_img, allow_pickle=False)
np.save(preproc_lbl_pth, preproc_lbl, allow_pickle=False)
return preproc_pth, list(preproc_pth.glob("*_x.npy")), list(preproc_pth.glob("*_y.npy"))
def _create_kits19_ref_dataloader(self, preproc_img_pths, preproc_lbl_pths, val):
if val:
dataset = PytVal(preproc_img_pths, preproc_lbl_pths)
else:
dataset = PytTrain(preproc_img_pths, preproc_lbl_pths, patch_size=(128, 128, 128), oversampling=0.4)
return iter(dataset)
def _create_kits19_tinygrad_dataloader(self, preproc_pth, val, batch_size=1, shuffle=False, seed=42, use_old_dataloader=False):
if use_old_dataloader:
dataset = iterate(list(Path(tempfile.gettempdir()).glob("case_*")), preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
else:
dataset = iter(batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed))
return iter(dataset)
def test_kits19_training_set(self):
preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(False)
ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, False)
tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, False)
for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
self._set_seed()
np.testing.assert_equal(tinygrad_sample[0][:, 0].numpy(), ref_sample[0])
np.testing.assert_equal(tinygrad_sample[1][:, 0].numpy(), ref_sample[1])
def test_kits19_validation_set(self):
preproc_pth, preproc_img_pths, preproc_lbl_pths = self._create_samples(True)
ref_dataset = self._create_kits19_ref_dataloader(preproc_img_pths, preproc_lbl_pths, True)
tinygrad_dataset = self._create_kits19_tinygrad_dataloader(preproc_pth, True, use_old_dataloader=True)
for ref_sample, tinygrad_sample in zip(ref_dataset, tinygrad_dataset):
np.testing.assert_equal(tinygrad_sample[0][:, 0], ref_sample[0])
np.testing.assert_equal(tinygrad_sample[1], ref_sample[1])
if __name__ == '__main__':
unittest.main()