carrot/tinygrad_repo/extra/optimization/pretrain_valuenet.py

89 lines
2.6 KiB
Python
Raw Normal View History

from tinygrad.codegen.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
# stuff needed to unpack a kernel
from tinygrad.uop.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.uop.ops import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
from extra.optimization.helpers import lin_to_feats, MAX_DIMS
# NOTE: this is not the real value of the state, it's just a prediction of the runtime
# Width shared by every hidden Linear layer in ValueNet.
INNER = 512
class ValueNet:
    """Four-layer fully connected network: feats -> INNER -> INNER -> INNER -> out.

    The first three layers are followed by ReLU; the third layer's activation
    additionally goes through ``dropout(0.8)`` before the final linear layer.
    """

    def __init__(self, feats=240, out=1):
        """Create the four Linear layers.

        Args:
            feats: size of the input feature vector.
            out: number of output values.
        """
        # Attribute names l1..l4 are kept because get_state_dict() is used on
        # this object when the weights are saved.
        self.l1 = Linear(feats, INNER)
        self.l2 = Linear(INNER, INNER)
        self.l3 = Linear(INNER, INNER)
        self.l4 = Linear(INNER, out)

    def __call__(self, x):
        """Run the forward pass on ``x`` and return the output of the last layer."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = layer(hidden).relu()
        # Dropout is applied only once, right before the output layer.
        hidden = self.l3(hidden).relu().dropout(0.8)
        return self.l4(hidden)
if __name__ == "__main__":
    # Train ValueNet to regress log(min(tms)) from the features of each logged
    # kernel, then save the weights and plot the loss curves.
    net = ValueNet()
    optim = Adam(get_parameters(net))

    TEST_SIZE = 256

    # Read the whole log up front; the context manager closes the file handle.
    with open("/tmp/logtm") as f:
        dset = f.read().strip().split("\n")
    random.seed(1337)
    random.shuffle(dset)

    X, Y = [], []
    for x in tqdm(dset):
        # SECURITY: eval() executes arbitrary Python. Each line is expected to
        # evaluate to an (ast, opts, tms) triple built from the names imported
        # at the top of this file, so only run this on log files you trust.
        ast, opts, tms = eval(x)
        lin = Kernel(ast)
        for o in opts:
            lin.apply_opt(o)
        # Skip kernels that lin_to_feats is not fed in this script, and
        # entries without any finite measurement.
        if lin.shape_len >= MAX_DIMS:
            continue
        best_tm = min(tms)
        if best_tm == float('inf'):
            continue
        X.append(lin_to_feats(lin))
        Y.append([math.log(best_tm)])
    print(f"got {len(X)} samples")

    # The last TEST_SIZE samples are held out; without this check an empty
    # training split would only fail later inside random.randint().
    if len(X) <= TEST_SIZE:
        raise SystemExit(f"need more than {TEST_SIZE} samples to train, got {len(X)}")
    X_test, Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:])
    X, Y = X[:-TEST_SIZE], Y[:-TEST_SIZE]

    def get_minibatch(X, Y, bs):
        """Return bs (features, target) pairs drawn uniformly with replacement."""
        picks = [random.randint(0, len(X) - 1) for _ in range(bs)]
        return Tensor([X[j] for j in picks]), Tensor([Y[j] for j in picks])

    Tensor.no_grad, Tensor.training = False, True
    losses = []
    test_losses = []
    test_loss = float('inf')  # placeholder until the first evaluation
    for i in (t := trange(2000)):
        x, y = get_minibatch(X, Y, bs=256)
        out = net(x)
        loss = (out - y).square().mean()
        optim.zero_grad()
        loss.backward()
        optim.step()

        loss_val = loss.numpy().item()
        t.set_description(f"loss {loss_val:7.2f}, test loss {test_loss:7.2f}")
        losses.append(loss_val)
        # Between evaluations this repeats the most recent test loss so both
        # curves have one point per step.
        test_losses.append(test_loss)

        # Evaluate on the held-out set every 10th step. The original
        # `if i % 10:` ran on the other nine steps instead.
        if i % 10 == 0:
            # ValueNet uses dropout in its forward pass; leave training mode
            # for the evaluation (assumes dropout honors Tensor.training --
            # confirm against the tinygrad docs) and always restore it so the
            # next optimizer step runs in training mode.
            Tensor.training = False
            try:
                test_loss = (net(X_test) - Y_test).square().mean().numpy().item()
            finally:
                Tensor.training = True

    safe_save(get_state_dict(net), "/tmp/valuenet.safetensors")

    # Imported here so matplotlib is only required when the script is run.
    import matplotlib.pyplot as plt
    # The first 200 steps are left out of the plot.
    plt.plot(losses[200:])
    plt.plot(test_losses[200:])
    plt.show()