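// Minimal PyTorch C++ extension that exposes tinygrad Tensors to PyTorch as opaque
// tensors on the PrivateUse1 backend: a no-op device guard, the PrivateUse1 hooks,
// and wrap/unwrap helpers that stash the Python object inside the TensorImpl.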
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <torch/extension.h>
#include <torch/csrc/PyInterpreter.h>
#include <ATen/OpaqueTensorImpl.h>
// register guard
namespace at {
namespace detail {
//C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
// NOTE: PyTorch's built-in no-op guard throws on backward because autograd records events/streams,
// so we roll our own that silently accepts them.
// TODO: why are there events in autograd?
struct CustomNoOpDeviceGuardImpl : public c10::impl::DeviceGuardImplInterface {
  static const DeviceType D = DeviceType::PrivateUse1;

  CustomNoOpDeviceGuardImpl() = default;

  DeviceType type() const override {
    return D;
  }
  Device exchangeDevice(Device) const override {
    return Device(D, 0); // no-op
  }
  Device getDevice() const override {
    return Device(D, 0);
  }
  void setDevice(Device) const override {
    // no-op
  }
  void uncheckedSetDevice(Device) const noexcept override {
    // no-op
  }
  Stream getStream(Device) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getDefaultStream(Device) const override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getStreamFromGlobalPool(Device, bool isHighPriority = false) const override {
    // no-op
    (void)isHighPriority;
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  Stream getNewStream(Device, int priority = 0) const override {
    // no-op
    (void)priority;
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream) const noexcept override {
    // no-op
    return Stream(Stream::DEFAULT, Device(D, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    return 1;
  }

  // Event-related functions: no-ops instead of the TORCH_CHECK failures in the
  // stock no-op guard, so autograd's event bookkeeping doesn't throw.
  void record(
      void** /*event*/,
      const Stream& /*stream*/,
      const DeviceIndex /*device_index*/,
      const EventFlag /*flag*/) const override {
    // TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  void block(void* /*event*/, const Stream& /*stream*/) const override {
    // TORCH_CHECK(false, D, " backend doesn't support events.");
  }
  bool queryEvent(void* /*event*/) const override {
    // TORCH_CHECK(false, D, " backend doesn't support events.");
    return true;
  }
  void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/) const noexcept override {}

  // Stream-related functions
  bool queryStream(const Stream& /*stream*/) const override {
    return true;
  }
  void synchronizeStream(const Stream& /*stream*/) const override {
    // Don't wait for anything.
  }
};
C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomNoOpDeviceGuardImpl);
} // namespace detail
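
// OpaqueTensorImpl's constructor only takes sizes; this subclass additionally stores
// the strides and storage offset recovered from the tinygrad view, so non-contiguous
// views keep their layout.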
template <typename OpaqueHandle>
struct TinyOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
  TinyOpaqueTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      c10::Device device,
      OpaqueHandle opaque_handle,
      c10::IntArrayRef sizes,
      c10::IntArrayRef strides,
      int64_t storage_offset)
      : OpaqueTensorImpl<OpaqueHandle>(key_set, data_type, device, opaque_handle, sizes) {
    this->sizes_and_strides_.set_strides(strides);
    this->storage_offset_ = storage_offset;
  }
};
} // namespace at
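
// PrivateUse1 hooks; some core paths query this interface, so register a minimal
// implementation. The temp_register_hook global below runs the registration as a
// side effect of loading the extension.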
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // NOTE: no idea what this is
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};

int register_hook() {
  at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
  return 0;
}
int temp_register_hook = register_hook();
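
// Wrap a tinygrad Tensor (a Python object) into an ATen tensor on the PrivateUse1
// device: sizes come from Tensor.shape, strides and storage offset from the
// ShapeTracker views, and the Python object itself is kept alive inside the
// TensorImpl as a SafePyObject opaque handle.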
at::Tensor wrap_tensor(py::object& py_obj, c10::ScalarType dtype, c10::DeviceIndex device_index) {
  // TODO: we have to get the dtype and the shape from the tinygrad Tensor
  std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
  py::list views = py_obj.attr("lazydata").attr("st").attr("views");
  std::vector<int64_t> strides = views[views.size() - 1].attr("strides").cast<std::vector<int64_t>>();
  int64_t storage_offset = 0;
  for (auto& v : views) {
    storage_offset += v.attr("offset").cast<int64_t>(); // TODO: is this correct?
  }
  return at::detail::make_tensor<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>>(
      at::DispatchKeySet(at::DispatchKey::PrivateUse1),
      c10::scalarTypeToTypeMeta(dtype),
      at::Device(at::kPrivateUse1, device_index),
      std::make_shared<c10::SafePyObject>(py_obj.release().ptr(), getPyInterpreter()),
      sizes, strides, storage_offset);
}
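
// Recover the original tinygrad Tensor from a wrapped ATen tensor. Assumes the impl
// really is a TinyOpaqueTensorImpl created by wrap_tensor; the static_cast is unchecked.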
py::object unwrap_tensor(const at::Tensor& tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  auto* opaque_impl = static_cast<at::TinyOpaqueTensorImpl<std::shared_ptr<c10::SafePyObject>>*>(impl);
  std::shared_ptr<c10::SafePyObject> tiny = opaque_impl->opaque_handle();
  return py::reinterpret_borrow<py::object>(tiny->ptr(getPyInterpreter()));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("wrap", &wrap_tensor);
  m.def("unwrap", &unwrap_tensor);
}
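
// Rough usage sketch from Python, assuming the extension is built via
// torch.utils.cpp_extension under the (hypothetical) module name "tiny_ext" and
// that the tinygrad Tensor still exposes lazydata.st.views:
//
//   import torch, tiny_ext
//   from tinygrad import Tensor
//
//   tg = Tensor.ones(2, 3)
//   t = tiny_ext.wrap(tg, torch.float32, 0)  # torch.Tensor on the PrivateUse1 device
//   assert tuple(t.shape) == (2, 3)
//   assert tiny_ext.unwrap(t) is tg          # same Python object comes back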