2025-01-07 19:31:23 +09:00
import os , json , pathlib , zipfile , pickle
from tqdm import tqdm
from typing import Dict , Union , List , Optional , Any , Tuple
2023-11-17 23:53:40 +00:00
from tinygrad . tensor import Tensor
2025-01-07 19:31:23 +09:00
from tinygrad . helpers import dtypes , prod , argsort , DEBUG , Timing , GlobalCounters , CI
2023-11-17 23:53:40 +00:00
from tinygrad . shape . view import strides_for_shape
2025-01-07 19:31:23 +09:00
from tinygrad . ops import Device
2023-11-17 23:53:40 +00:00
2025-01-07 19:31:23 +09:00
# Mapping from safetensors dtype strings to tinygrad dtypes, and its inverse for saving.
safe_dtypes = {
  "F16": dtypes.float16,
  "F32": dtypes.float32,
  "U8": dtypes.uint8,
  "I8": dtypes.int8,
  "I32": dtypes.int32,
  "I64": dtypes.int64,
}
inverse_safe_dtypes = {dtype: name for name, dtype in safe_dtypes.items()}
2025-01-07 19:31:23 +09:00
def safe_load_metadata(fn: Union[Tensor, str]) -> Tuple[Tensor, int, Any]:
  """Open a safetensors blob and return (backing tensor, JSON header length, parsed JSON header).

  fn may be a path (mapped as a disk tensor) or an already-opened uint8 Tensor.
  """
  if isinstance(fn, Tensor):
    t = fn
  else:
    t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  # the first 8 bytes hold the JSON header length (read here as an int64)
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
  header = json.loads(t[8:8+json_len].numpy().tobytes())
  return (t, json_len, header)
2023-11-17 23:53:40 +00:00
2025-01-07 19:31:23 +09:00
def safe_load(fn: Union[Tensor, str]) -> Dict[str, Tensor]:
  """Load a safetensors file (or disk-backed Tensor) into a dict of name -> Tensor.

  Each returned tensor is a cast/reshaped view over the file's data region.
  """
  t, json_len, metadata = safe_load_metadata(fn)
  data_start = 8 + json_len
  ret: Dict[str, Tensor] = {}
  for name, info in metadata.items():
    if name == "__metadata__": continue  # file-level metadata entry, not a tensor
    raw = t[data_start + info['data_offsets'][0]:]
    ret[name] = raw.cast(safe_dtypes[info['dtype']])[:prod(info['shape'])].reshape(info['shape'])
  return ret
2023-11-17 23:53:40 +00:00
2025-01-07 19:31:23 +09:00
def safe_save(tensors: Dict[str, Tensor], fn: str, metadata: Optional[Dict[str, Any]] = None):
  """Write a dict of name -> Tensor to fn in safetensors format.

  An optional metadata dict is stored under the "__metadata__" header key.
  """
  headers: Dict[str, Any] = {}
  if metadata: headers['__metadata__'] = metadata
  # lay out each tensor's byte span in the data section, in dict order
  offset = 0
  for name, tensor in tensors.items():
    nbytes = tensor.nbytes()
    headers[name] = {'dtype': inverse_safe_dtypes[tensor.dtype], 'shape': list(tensor.shape), 'data_offsets': [offset, offset + nbytes]}
    offset += nbytes
  j = json.dumps(headers, separators=(',', ':'))
  # pad the JSON header with spaces to an 8-byte boundary
  j += "\x20" * ((8 - len(j) % 8) % 8)
  # recreate the file, sized for the 8-byte length prefix + header + data
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8 + len(j) + offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:1].cast(dtypes.int64).assign([len(j)])  # length prefix written as int64
  t[8:8 + len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
  # reopen the file via safe_load and copy each tensor into its slot
  for name, slot in safe_load(t).items(): slot.assign(tensors[name])
# state dict
2025-01-07 19:31:23 +09:00
from collections import OrderedDict
def get_state_dict(obj, prefix: str = '', tensor_type=Tensor) -> Dict[str, Tensor]:
  """Recursively walk obj and collect every tensor_type leaf into a flat dict.

  Keys are dot-separated paths built from attribute names, dict keys, and
  list/tuple indexes. The check order below is significant: tensor leaves
  first, then namedtuples, OrderedDicts, arbitrary objects, then sequences
  and plain dicts.
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'): obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)  # arbitrary object
  collected: Dict[str, Tensor] = {}
  if isinstance(obj, (list, tuple)):
    for idx, item in enumerate(obj):
      collected.update(get_state_dict(item, f"{prefix}{str(idx)}.", tensor_type))
  elif isinstance(obj, dict):
    for key, val in obj.items():
      collected.update(get_state_dict(val, f"{prefix}{str(key)}.", tensor_type))
  return collected
2025-01-07 19:31:23 +09:00
def get_parameters ( obj ) - > List [ Tensor ] : return list ( get_state_dict ( obj ) . values ( ) )
2023-11-17 23:53:40 +00:00
2025-01-07 19:31:23 +09:00
def load_state_dict(model, state_dict, strict=True, verbose=True):
  """Copy tensors from state_dict into the matching entries of model's state dict.

  With strict=False, model keys missing from state_dict are skipped (with a
  warning under DEBUG); with strict=True a missing key raises KeyError.
  Progress is shown with tqdm unless verbose is False or running under CI.
  """
  with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
    model_state_dict = get_state_dict(model)
    # extra keys in state_dict are never touched; surface them under DEBUG
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    bar = tqdm(model_state_dict.items(), disable=CI or not verbose)
    for k, v in bar:
      bar.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
      # move the source tensor to the destination device, then copy in place
      v.assign(state_dict[k].to(v.device)).realize()
2023-11-17 23:53:40 +00:00
# torch support!
2025-01-07 19:31:23 +09:00
def torch_load(fn:str):
  """Load a PyTorch checkpoint from fn into tinygrad Tensors backed by the file on disk.

  Handles both the zip-based format (files starting with the 'PK' magic) and the
  older single-pickle legacy format.
  """
  # map the whole file as a uint8 disk tensor; tensors are carved out of it by byte offset
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

  # per-storage-key byte offset of the raw data inside the file, and byte length of each storage
  offsets: Dict[str, int] = {}
  lens: Dict[str, int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
    # Called by the unpickler in place of torch's tensor reconstructor.
    # NOTE(review): from the usage below, storage[1] is the dtype, storage[2] the
    # storage key, and storage[4] the element count -- confirm against torch's
    # persistent-id tuple layout.
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
    lens[storage[2]] = storage[4] * storage[1].itemsize
    # during the legacy format's first (offset-discovery) pass, offsets is still empty
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]] + storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)].cast(storage[1])
    # convert bfloat16 -> float16 using LLVM for Llama 2
    # upstream LLaMA also does this conversion:
    # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
    # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
    if storage[1] == dtypes.bfloat16:
      # widen the bf16 bit pattern to f32 (shift into the high 16 bits), then narrow to f16
      ret = ret.bitcast(dtypes.uint16).to("CPU").cast(dtypes.uint32).mul(1 << 16).bitcast(dtypes.float32).to(Device.DEFAULT).half()
      #ret = ret.to("LLVM").half().to(Device.DEFAULT)

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    # drop size-1 dims, then see if the remaining strides are already in C order
    shape_strides = [(s, st) for s, st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides) - 1 - y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      # non-contiguous: reshape to the stride-sorted shape, then permute back
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 2: print(f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}")
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  # map torch storage class names to tinygrad dtypes, and hook the tensor reconstructor
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
  whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      # anything outside the whitelist is replaced by Dummy rather than imported
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    # keep the raw persistent-id tuple; _rebuild_tensor_v2 reads it directly
    def persistent_load(self, pid): return pid

  if tuple(t[0:2].numpy()) == (0x50, 0x4b):
    # zip-based checkpoint ('PK' magic): each storage lives in <base>/data/<key>
    myzip = zipfile.ZipFile(fn, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          # private zipfile attr: file offset where this member's data starts
          # (usable as a raw offset since the data is stored, not compressed) -- TODO confirm
          offsets[n.split("/")[-1]] = myfile._orig_compress_start  # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  else:
    # legacy single-pickle format: needs two passes since storage lengths are
    # only known after reading the model pickle once
    with open(fn, "rb") as f:
      pkl = TorchPickle(f)
      # skip magic/protocol/sys_info pickles, remember where the model pickle
      # starts (rwd), do a discovery load to fill lens, read the storage-key
      # list (ids), and note where the raw storage data begins (base_offset)
      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
      for i in ids:
        # each storage is preceded by an 8-byte record before its data -- NOTE(review):
        # presumably a little-endian element-count header; verify against torch's legacy writer
        offsets[i] = base_offset + 8
        base_offset += 8 + lens[i]
      # rewind and load the model pickle again, now with offsets populated
      f.seek(rwd)
      return TorchPickle(f).load()