<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>tinygrad has WebGPU</title>

  <style>
    body {
      font-family: 'Arial', sans-serif;
      text-align: center;
      padding: 30px;
    }

    a {
      text-decoration: none;
      color: #4A90E2;
    }

    h1 {
      font-size: 36px;
      font-weight: normal;
      margin-bottom: 20px;
    }

    #mybox {
      display: flex;
      flex-direction: column;
      align-items: center;
      gap: 20px;
      width: 50%;
      margin: 0 auto;
    }

    #promptText, #stepRange, #btnRunNet, #guidanceRange {
      font-size: 18px;
      width: 100%;
    }

    #result {
      font-size: 48px;
    }

    #time {
      font-size: 16px;
      color: grey;
    }

    canvas {
      margin-top: 20px;
      border: 1px solid #000;
    }

    label {
      display: flex;
      align-items: center;
      gap: 10px;
      width: 100%;
    }

    #sliderValue {
      margin-right: 10px;
    }
  </style>

  <script type="module">
    import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
    window.clipTokenizer = new ClipTokenizer();
  </script>
  <script src="./f16_to_f32.js"></script>
  <script src="./net.js"></script>
</head>

<body>
  <h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
  <h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
  <div id="mybox">
    <input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">

    <label>
      Steps: <span id="stepValue">8</span>
      <input id="stepRange" type="range" min="5" max="20" value="8" step="1">
    </label>

    <label>
      Guidance: <span id="guidanceValue">7.5</span>
      <input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
    </label>

    <input id="btnRunNet" type="button" value="Run" disabled>

    <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
      <span id="modelDlTitle">Downloading model</span>
      <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
      <span id="modelDlProgressValue"></span>
    </div>

    <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
      <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
      <span id="progressFraction"></span>
    </div>
  </div>
  <canvas id="canvas" width="512" height="512"></canvas>

  <script>
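    // Page flow: cache model weights in IndexedDB, download the safetensors
    // parts on first visit, expand the f16 weights to f32 on the GPU, compile
    // the three networks (text encoder, diffusor, decoder), then run the
    // sampling loop when the user hits Run.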
    function initDb() {
      return new Promise((resolve, reject) => {
        let db;
        const request = indexedDB.open('tinydb', 1);
        request.onerror = (event) => {
          console.error('Database error:', event.target.error);
          resolve(null);
        };

        request.onsuccess = (event) => {
          db = event.target.result;
          console.log("Db initialized.");
          resolve(db);
        };

        request.onupgradeneeded = (event) => {
          db = event.target.result;
          if (!db.objectStoreNames.contains('tensors')) {
            db.createObjectStore('tensors', { keyPath: 'id' });
          }
        };
      });
    }
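
    // Persist a tensor blob under the given id. All three DB helpers resolve
    // with null instead of rejecting, so a broken or unavailable IndexedDB just
    // means the model gets re-downloaded on the next visit.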
    function saveTensorToDb(db, id, tensor) {
      return new Promise((resolve, reject) => {
        if (db == null) {
          resolve(null);
          return;
        }

        const transaction = db.transaction(['tensors'], 'readwrite');
        const store = transaction.objectStore('tensors');
        const request = store.put({ id: id, content: tensor });

        transaction.onabort = (event) => {
          console.log("Transaction error while saving tensor: " + event.target.error);
          resolve(null);
        };

        request.onsuccess = () => {
          console.log('Tensor saved successfully.');
          resolve();
        };

        request.onerror = (event) => {
          console.error('Tensor save failed:', event.target.error);
          resolve(null);
        };
      });
    }

    function readTensorFromDb(db, id) {
      return new Promise((resolve, reject) => {
        if (db == null) {
          resolve(null);
          return;
        }

        const transaction = db.transaction(['tensors'], 'readonly');
        const store = transaction.objectStore('tensors');
        const request = store.get(id);

        transaction.onabort = (event) => {
          console.log("Transaction error while reading tensor: " + event.target.error);
          resolve(null);
        };

        request.onsuccess = (event) => {
          const result = event.target.result;
          if (result) {
            console.log("Cache hit: " + id);
            resolve(result);
          } else {
            console.log("Cache miss: " + id);
            resolve(null);
          }
        };

        request.onerror = (event) => {
          console.error('Tensor retrieve failed: ', event.target.error);
          resolve(null);
        };
      });
    }
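
    // Main entry point: feature-detect WebGPU, open the cache, wire up the UI,
    // then kick off model loading.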
    window.addEventListener('load', async function() {
      if (!navigator.gpu) {
        document.getElementById("wgpuError").style.display = "";
        document.getElementById("sdTitle").style.display = "none";
        return;
      }

      let db = await initDb();

      const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
      let nets, safetensorParts;
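
      // The largest Stable Diffusion weight buffer is far bigger than WebGPU's
      // spec defaults (128 MiB maxStorageBufferBindingSize, 256 MiB
      // maxBufferSize), so request 1 GiB for both limits when creating the device.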
      const getDevice = async () => {
        const adapter = await navigator.gpu.requestAdapter();
        const requiredLimits = {};
        const maxBufferSizeInSDModel = 1073741824;
        requiredLimits.maxStorageBufferBindingSize = maxBufferSizeInSDModel;
        requiredLimits.maxBufferSize = maxBufferSizeInSDModel;

        return await adapter.requestDevice({
          requiredLimits
        });
      };

      const timer = async (func, label = "") => {
        const start = performance.now();
        const out = await func();
        const delta = (performance.now() - start).toFixed(1);
        console.log(`${delta} ms ${label}`);
        return out;
      };
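
      // Stream one model part so download progress can be reported per chunk.
      // Assumes the server sends a Content-Length header for the total size.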
      const getProgressDlForPart = async (part, progressCallback) => {
        const response = await fetch(part);
        const contentLength = response.headers.get('content-length');
        const total = parseInt(contentLength, 10);

        const res = new Response(new ReadableStream({
          async start(controller) {
            const reader = response.body.getReader();
            for (;;) {
              const { done, value } = await reader.read();
              if (done) break;
              progressCallback(part, value.byteLength, total);
              controller.enqueue(value);
            }

            controller.close();
          },
        }));

        return res.arrayBuffer();
      };
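
      // The weights ship as f16 safetensors split into four parts, plus an
      // already-f32 text model. Download them (or pull from the cache), expand
      // the f16 data to f32 on the GPU, and reassemble the result into four
      // f32 buffers laid out like the original safetensors file.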
      const getAndDecompressF16Safetensors = async (device, progress) => {
        let totalLoaded = 0;
        let totalSize = 0;
        let partSize = {};

        const progressCallback = (part, loaded, total) => {
          totalLoaded += loaded;

          if (!partSize[part]) {
            totalSize += total;
            partSize[part] = true;
          }

          progress(totalLoaded, totalSize);
        };
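
        // Cache first: both blobs may already be in IndexedDB from an earlier visit.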
        let combinedBuffer = await readTensorFromDb(db, "net.f16");
        let textModelU8 = await readTensorFromDb(db, "net.text");
        let textModelFetched = false;

        if (combinedBuffer == null) {
          let dlParts = [
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
          ];

          if (textModelU8 == null) {
            dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
          }

          let buffers = await Promise.all(dlParts);

          // Combine everything except the text model, since that's already f32
          const totalLength = buffers.reduce((acc, buffer, index) => {
            return index < 4 ? acc + buffer.byteLength : acc;
          }, 0);

          combinedBuffer = new Uint8Array(totalLength);

          let offset = 0;
          buffers.forEach((buffer, index) => {
            if (index < 4) {
              combinedBuffer.set(new Uint8Array(buffer), offset);
              offset += buffer.byteLength;
              buffer = null;
            }
          });

          await saveTensorToDb(db, "net.f16", combinedBuffer);

          if (textModelU8 == null) {
            textModelFetched = true;
            textModelU8 = new Uint8Array(buffers[4]);
            await saveTensorToDb(db, "net.text", textModelU8);
          }
        } else {
          combinedBuffer = combinedBuffer.content;
        }

        if (textModelU8 == null) {
          textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
          await saveTensorToDb(db, "net.text", textModelU8);
        } else if (!textModelFetched) {
          textModelU8 = textModelU8.content;
        }
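
        // Next: parse the header and expand the weights. Bytes 0-7 of a
        // safetensors file are a little-endian u64 holding the length of the
        // JSON metadata header; tensor data follows. The constants below
        // (text-model offset, part boundaries) are baked in from this export.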
        document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

        const textModelOffset = 3772703308;
        const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
        const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

        const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
        const decodeChunkSize = 67107840;
        const numChunks = Math.ceil(allToDecomp / decodeChunkSize);

        console.log(allToDecomp + " bytes to decompress");
        console.log("Will be decompressed in " + numChunks + " chunks");

        let partOffsets = [
          {start: 0, end: 1131408336},
          {start: 1131408336, end: 2227518416},
          {start: 2227518416, end: 3308987856},
          {start: 3308987856, end: 4265298864}
        ];
        let parts = [];

        for (let offsets of partOffsets) {
          parts.push(new Uint8Array(offsets.end - offsets.start));
        }
        parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
        parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
        parts[3].set(textModelU8, textModelOffset + 8 + metadataLength - partOffsets[3].start);
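
        // Expand f16 to f32 chunk by chunk on the GPU. Each output chunk is
        // twice the input size and may straddle a part boundary, hence the
        // cursor/split logic below. Chunks are padded to a 4-byte multiple
        // before the GPU pass.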
        let start = Date.now();
        let cursor = 0;

        for (let i = 0; i < numChunks; i++) {
          progress(i, numChunks);
          let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
          let chunkEndF16 = chunkStartF16 + decodeChunkSize;
          let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);

          if (chunk.byteLength % 4 != 0) {
            const paddingBytes = 4 - (chunk.byteLength % 4);
            const alignedBuffer = new ArrayBuffer(chunk.byteLength + paddingBytes);
            const alignedView = new Uint8Array(alignedBuffer);
            alignedView.set(new Uint8Array(chunk));
            chunk = alignedView;
          }

          let result = await f16tof32GPU(device, chunk);
          let resultUint8 = new Uint8Array(result.buffer);
          let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
          let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
          let offsetInPart = chunkStartF32 - partOffsets[cursor].start;

          if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
            parts[cursor].set(resultUint8, offsetInPart);
          } else {
            let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
            parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);

            cursor++;

            if (cursor < parts.length) {
              let nextPartOffset = spaceLeftInCurrentPart;
              let nextPartLength = resultUint8.length - nextPartOffset;
              parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
            }
          }

          resultUint8 = null;
          result = null;
        }

        combinedBuffer = null;

        let end = Date.now();
        console.log("Decoding took: " + ((end - start) / 1000) + " s");
        console.log("Average " + ((end - start) / numChunks) + " ms per chunk");

        return parts;
      };
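
      // Download, decompress and compile everything, then enable the Run button.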
      const loadNet = async () => {
        const modelDlTitle = document.getElementById("modelDlTitle");

        const progress = (loaded, total) => {
          document.getElementById("modelDlProgressBar").value = (loaded / total) * 100;
          document.getElementById("modelDlProgressValue").innerHTML = Math.trunc((loaded / total) * 100) + "%";
        };

        const device = await getDevice();
        safetensorParts = await getAndDecompressF16Safetensors(device, progress);

        modelDlTitle.innerHTML = "Compiling model";

        let models = ["textModel", "diffusor", "decoder"];

        nets = await timer(() => Promise.all([
          textModel().setup(device, safetensorParts),
          diffusor().setup(device, safetensorParts),
          decoder().setup(device, safetensorParts)
        ]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)");

        progress(1, 1);

        modelDlTitle.innerHTML = "Model ready";
        setTimeout(() => {
          document.getElementById("modelDlProgressBar").style.display = "none";
          document.getElementById("modelDlProgressValue").style.display = "none";
          document.getElementById("divStepProgress").style.display = "flex";
        }, 1000);
        document.getElementById("btnRunNet").disabled = false;
      };
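
      // DDIM-style sampling: encode the prompt and an empty prompt (for
      // classifier-free guidance), pick `steps` evenly spaced timesteps out of
      // 1000, look up the matching alphas_cumprod values, start from a Gaussian
      // 4x64x64 latent, and denoise backwards through the timesteps.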
      async function runStableDiffusion(prompt, steps, guidance) {
        let context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
        let unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

        let timesteps = [];

        for (let i = 1; i < 1000; i += (1000 / steps)) {
          timesteps.push(i);
        }

        console.log("Timesteps: " + timesteps);

        let alphasCumprod = getWeight(safetensorParts, "alphas_cumprod");
        let alphas = [];

        for (const t of timesteps) {
          alphas.push(alphasCumprod[Math.floor(t)]);
        }

        let alphas_prev = [1.0];

        for (let i = 0; i < alphas.length - 1; i++) {
          alphas_prev.push(alphas[i]);
        }

        let inpSize = 4 * 64 * 64;
        let latent = new Float32Array(inpSize);

        // Standard normal noise via the Box-Muller transform
        for (let i = 0; i < inpSize; i++) {
          latent[i] = Math.sqrt(-2.0 * Math.log(Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
        }

        for (let i = timesteps.length - 1; i >= 0; i--) {
          let timestep = new Float32Array([timesteps[i]]);
          let x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
          latent = x_prev;
          document.getElementById("progressBar").value = ((steps - i) / steps) * 100;
          document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps;
        }

        return await timer(() => nets["decoder"](latent));
      }
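
      // The decoder returns a flat RGB buffer (three values per pixel);
      // putImageData wants RGBA, so insert an opaque alpha byte after every
      // third value before drawing to the canvas.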
      document.getElementById("btnRunNet").addEventListener("click", function(e) {
        e.target.disabled = true;

        runStableDiffusion(document.getElementById("promptText").value, document.getElementById("stepRange").value, document.getElementById("guidanceRange").value).then((image) => {
          let pixels = [];
          let pixelCounter = 0;

          for (let j = 0; j < 512; j++) {
            for (let k = 0; k < 512; k++) {
              pixels.push(image[pixelCounter]);
              pixels.push(image[pixelCounter + 1]);
              pixels.push(image[pixelCounter + 2]);
              pixels.push(255);
              pixelCounter += 3;
            }
          }

          ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), 512, 512), 0, 0);
          console.log(image);
          console.log("Success");
          e.target.disabled = false;
        });
      }, false);

      const stepSlider = document.getElementById('stepRange');
      const stepValue = document.getElementById('stepValue');

      stepSlider.addEventListener('input', function() {
        stepValue.textContent = stepSlider.value;
      });

      const guidanceSlider = document.getElementById('guidanceRange');
      const guidanceValue = document.getElementById('guidanceValue');

      guidanceSlider.addEventListener('input', function() {
        guidanceValue.textContent = guidanceSlider.value;
      });

      loadNet();
    });
  </script>
</body>

</html>