#include <torch/extension.h>
#include <torch/csrc/PyInterpreter.h>
#include <ATen/OpaqueTensorImpl.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>

// register guard
namespace at {
namespace detail {

//C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
// NOTE: pytorch's no-op guard class throws an error on backward when events/streams are used
// TODO: why are there events in autograd?

struct CustomNoOpDeviceGuardImpl : public c10::impl::DeviceGuardImplInterface {
  static const DeviceType D = DeviceType::PrivateUse1;
  CustomNoOpDeviceGuardImpl() = default;
  DeviceType type() const override {
    return D;
  }
  Device exchangeDevice(Device) const override {
    return Device(D, 0); // no-op
  }
  Device getDevice() const override {
    return Device(D, 0);
  }
  void setDevice(Device) const override {
    // no-op
  }
  void uncheckedSetDevice(Device) const noexcept override {
    // no-op
  }
  Stream getStream(Device) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getDefaultStream(Device) const override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getStreamFromGlobalPool(Device, bool isHighPriority = false) const override {
    // no-op
    (void)isHighPriority;
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getNewStream(Device, int priority = 0) const override {
    // no-op
    (void)priority;
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions: accept everything silently so autograd's
  // event bookkeeping doesn't abort on this backend
  void record(
      void** /*event*/,
      const Stream& /*stream*/,
      const DeviceIndex /*device_index*/,
      const EventFlag /*flag*/) const override {
    //TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  void block(void* /*event*/, const Stream& /*stream*/) const override {
    //TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  bool queryEvent(void* /*event*/) const override {
    //TORCH_CHECK(false, D, " backend doesn't support events.");
    return true;
  }
  void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) const noexcept override {}

  // Stream-related functions
  bool queryStream(const Stream& /*stream*/) const override {
    return true;
  }
  void synchronizeStream(const Stream& /*stream*/) const override {
    // Don't wait for anything.
  }
};

C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomNoOpDeviceGuardImpl);

} // namespace detail

// opaque tensor impl that owns a handle to the underlying tinygrad Tensor,
// plus the strides/storage_offset that OpaqueTensorImpl doesn't take itself
template <typename OpaqueHandle>
struct TinyOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
  TinyOpaqueTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::Device device,
      OpaqueHandle opaque_handle,
      c10::IntArrayRef sizes,
      c10::IntArrayRef strides,
      int64_t storage_offset)
      : OpaqueTensorImpl<OpaqueHandle>(key_set, data_type, device, opaque_handle, sizes) {
    this->sizes_and_strides_.set_strides(strides);
    this->storage_offset_ = storage_offset;
  }
};

} // namespace at

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // NOTE: no idea what this is
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};

int register_hook() {
  at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
  return 0;
}
int temp_register_hook = register_hook();

at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) {
  // TODO: we have to get the dtype and the shape from the tinygrad Tensor
  std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
  py::list views = py_obj.attr("lazydata").attr("st").attr("views");
  std::vector<int64_t> strides = views[views.size() - 1].attr("strides").cast<std::vector<int64_t>>();
  int64_t storage_offset = 0;
  for (auto& v : views) {
    storage_offset += v.attr("offset").cast<int64_t>(); // TODO: is this correct?
  }
  return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
    at::DispatchKeySet(at::DispatchKey::PrivateUse1),
    c10::scalarTypeToTypeMeta(dtype),
    at::Device(at::kPrivateUse1, device_index),
    // SafePyObject takes ownership of the released reference and keeps the
    // tinygrad Tensor alive for as long as the torch Tensor exists
    std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
    sizes, strides, storage_offset);
}

py::object unwrap_tensor(const at::Tensor &tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  auto* opaque_impl = static_cast<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
  std::shared_ptr<c10::SafePyObject> tiny = opaque_impl->opaque_handle();
  return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("wrap", &wrap_tensor);
  m.def("unwrap", &unwrap_tensor);
}
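// Usage sketch (not from this file): one way to build and exercise the module
// from Python. The module/source names "wrapped_tensor"/"wrapped_tensor.cpp"
// and the tinygrad import are assumptions for illustration; wrap/unwrap and
// their (tensor, dtype, device_index) signature come from the bindings above.
//
//   import torch
//   from torch.utils.cpp_extension import load
//   from tinygrad import Tensor
//
//   # TORCH_EXTENSION_NAME expands to the name passed here
//   mod = load(name="wrapped_tensor", sources=["wrapped_tensor.cpp"])
//
//   t = Tensor.ones(2, 3)
//   wrapped = mod.wrap(t, torch.float32, 0)  # torch.Tensor on privateuseone:0
//   print(wrapped.shape, wrapped.device)     # torch.Size([2, 3]) privateuseone:0
//   assert mod.unwrap(wrapped) is t          # borrows the same tinygrad Tensor back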
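// For reference, the metadata wrap_tensor reads can be inspected from Python
// with the same attribute chain (assuming a tinygrad version where
// Tensor.lazydata.st exposes the ShapeTracker, as the code above does):
//
//   views = t.lazydata.st.views
//   t.shape                         # -> sizes
//   views[-1].strides               # -> strides (taken from the last view)
//   sum(v.offset for v in views)    # -> storage_offset (summed across views)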