# thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code
import sys
import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional
from extra.utils import download_file
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv
import tinygrad.nn as nn
from tinygrad.tensor import Tensor
import itertools
import librosa
# TODO: you have written this fifteen times
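# Standard multi-head attention: the query always comes from x, while the key and
# value come from xa (the encoder output) when cross-attending and from x otherwise.
# Scaling both q and k by (n_state // n_head) ** -0.25 is equivalent to dividing the
# attention scores by sqrt(d_head).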
class MultiHeadAttention:
def __init__(self, n_state, n_head):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None):
q = self.query(x)
k = self.key(xa or x)
v = self.value(xa or x)
wv, qk = self.qkv_attention(q, k, v, mask)
# NOTE: we aren't returning qk
return self.out(wv)
def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.reshape(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.reshape(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.reshape(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None: qk = qk + mask[:n_ctx, :n_ctx]
w = qk.softmax(-1)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
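# Pre-LayerNorm transformer block: self-attention, optional cross-attention (used by
# the decoder), and a 4x GELU MLP, each wrapped in a residual connection.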
class ResidualAttentionBlock:
def __init__(self, n_state, n_head, cross_attention=False):
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None):
x = x + self.attn(self.attn_ln(x), mask=mask)
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x
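# Audio encoder: two 1-D convolutions with GELU (the second strides by 2, halving the
# 3000 mel frames to 1500 positions), fixed positional embeddings loaded from the
# checkpoint, a stack of self-attention blocks, and a final LayerNorm.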
class AudioEncoder:
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
def __call__(self, x):
x = self.conv1(x).gelu()
x = self.conv2(x).gelu()
x = x.permute(0, 2, 1)
x = x + self.positional_embedding[:x.shape[1]]
x = x.sequential(self.blocks)
x = self.ln_post(x)
return x
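# Text decoder: token embedding plus learned positional embedding, a causal
# (upper-triangular -inf) self-attention mask, cross-attention into the encoder output
# in every block, and output logits taken against the transposed token embedding
# (weight tying).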
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, cross_attention=True) for _ in range(n_text_layer)]
self.ln = nn.LayerNorm(n_text_state)
#mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
def __call__(self, x, xa):
offset = 0
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
seqlen, start_pos = x.shape[1], 0
mask = np.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=np.float32)
mask = np.triu(mask, k=start_pos + 1) # TODO: this is hard to do in tinygrad
mask = Tensor(mask)
for block in self.blocks: x = block(x, xa, mask)
x = self.ln(x)
return x @ self.token_embedding.weight.T
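# mel is the (1, 80, 3000) log-mel spectrogram from prep_audio, tokens a (1, seq_len)
# array of token ids; the result is the decoder's per-position logits over the vocab.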
class Whisper:
def __init__(self, dims):
self.encoder = AudioEncoder(**dims)
self.decoder = TextDecoder(**dims)
def __call__(self, mel:Tensor, tokens:Tensor):
return self.decoder(tokens, self.encoder(mel))
RATE = 16000  # Whisper models expect 16 kHz audio
CHUNK = 1600  # 100 ms of audio per microphone read
RECORD_SECONDS = 10
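# prep_audio converts a 16 kHz waveform into Whisper's 80-bin log-mel spectrogram:
# 400-sample FFT windows (25 ms) with a 160-sample hop (10 ms), log10 of the mel
# energies clamped to within 8.0 of the maximum, scaled by (x + 4) / 4, and
# zero-padded out to 3000 frames (30 s).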
def prep_audio(waveform) -> np.ndarray:
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
assert waveform is not None
waveform = waveform.reshape(1, -1)
  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
  magnitudes = np.absolute(stft[..., :-1]) ** 2
  mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # keep 8 units of log10 dynamic range
  log_spec = (log_spec + 4.0) / 4.0
# https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
n_frames = log_spec.shape[2]
if n_frames < 3000:
log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))
#print(waveform.shape, log_spec.shape)
return log_spec
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
BASE = pathlib.Path(__file__).parents[1] / "weights"
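# Build the tiktoken BPE Whisper uses: GPT-2 mergeable ranks plus special tokens for
# <|startoftranscript|>, one token per language, the task tokens, and 1501 timestamp
# tokens in 0.02 s steps (0.00 through 30.00 s). The resulting vocab size must match
# the checkpoint's n_vocab.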
def get_encoding(n_vocab_in):
download_file("https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken", BASE / "gpt2.tiktoken")
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in open(BASE / "gpt2.tiktoken") if line)}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
n_vocab += len(specials)
assert n_vocab == n_vocab_in
import tiktoken
return tiktoken.Encoding(
name="bob",
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens)
def img(x):
import matplotlib.pyplot as plt
plt.imshow(x.numpy())
plt.show()
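# Microphone capture: read 16-bit mono audio at 16 kHz in 1600-sample (100 ms) chunks
# for RECORD_SECONDS seconds, convert each chunk to float32 (scaled up by 3), and push
# it onto the queue for the main process to consume.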
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
q.put(waveform)
print("done listening")
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
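# Download the requested checkpoint if needed, build the model from the 'dims' stored
# in it, load the weights, and construct the matching tokenizer.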
def init_whisper(model_name="tiny.en"):
  assert model_name in MODEL_URLS, f"unknown model {model_name}"
filename = BASE / "whisper-{}.pt".format(model_name)
download_file(MODEL_URLS[model_name], filename)
state = torch_load(filename)
model = Whisper(state['dims'])
load_state_dict(model, state['model_state_dict'])
enc = get_encoding(state['dims']['n_vocab'])
return model, enc
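# Offline transcription: encode the whole file once, then greedily decode up to 50
# tokens starting from <|startoftranscript|><|notimestamps|>, printing the running
# transcription and returning it once <|endoftext|> appears.
#
# Minimal usage sketch (the wav filename here is only an example):
#   model, enc = init_whisper("tiny.en")
#   text = transcribe_file(model, enc, "speech.wav")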
def transcribe_file(model, enc, filename):
waveform, sample_rate = librosa.load(filename, sr=RATE)
log_spec = prep_audio(waveform)
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
dat = model.encoder(Tensor(log_spec)).realize()
for i in range(50):
out = model.decoder(Tensor([lst]), dat).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
transcription = enc.decode(lst)
print(transcription)
if lst[-1] == enc._special_tokens["<|endoftext|>"]:
return transcription
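# Run with an audio-file path as the first argument to transcribe that file, or with
# no arguments to transcribe live microphone input: a background listener process
# feeds audio chunks through a queue while the main loop re-encodes the growing
# waveform and greedily decodes one more token per iteration, popping <|endoftext|>
# so decoding can continue as more audio arrives. Set SMALL=1 to use small.en
# instead of tiny.en.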
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
if len(sys.argv) > 1:
transcribe_file(model, enc, sys.argv[1])
else:
# online
q = multiprocessing.Queue()
p = multiprocessing.Process(target=listener, args=(q,))
p.daemon = True
p.start()
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
total = None
did_read = False
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
while not q.empty() or total is None:
waveform = q.get()
if total is None: total = waveform
else: total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(total)
encoded_audio = model.encoder(Tensor(log_spec)).realize()
out = model.decoder(Tensor([lst]), encoded_audio).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
if dec.endswith("<|endoftext|>"):
#total = total[:, 320*(len(lst)-1):]
lst.pop()