-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbitsqueeze_pytorch.cpp
More file actions
148 lines (123 loc) · 5.24 KB
/
bitsqueeze_pytorch.cpp
File metadata and controls
148 lines (123 loc) · 5.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <torch/extension.h>
#include <pybind11/stl.h>
#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
extern "C" {
#include "include/bitsqueeze.h"
}
/// Translate a user-facing method identifier into the BitSqueeze enum value.
///
/// @param name  Case-sensitive identifier such as "Q8_0", "NF4" or "TOPK_IM".
/// @return      The bsq_method_t enumerator registered under `name`.
/// @throws std::runtime_error when `name` is not one of the known identifiers.
bsq_method_t get_method_from_string(const std::string& name) {
    struct NamedMethod {
        const char* label;
        bsq_method_t value;
    };
    // Lookup table: identifier -> enumerator. Order mirrors the C enum list.
    static const NamedMethod kKnownMethods[] = {
        {"Q8_0", Q8_0},
        {"Q4_0", Q4_0},
        {"Q2_K", Q2_K},
        {"TOPK", TOPK},
        {"BF16", BF16},
        {"FP16", FP16},
        {"FP8", FP8},
        {"FP4", FP4},
        {"MXFP8", MXFP8},
        {"MXFP4", MXFP4},
        {"NVFP4", NVFP4},
        {"NF4_DQ", NF4_DQ},
        {"NF4", NF4},
        {"IQ2_XXS", IQ2_XXS},
        {"IQ2_XS", IQ2_XS},
        {"IQ2_S", IQ2_S},
        {"Q2_K_FAST", Q2_K_FAST},
        {"TOPK_IM", TOPK_IM},
    };
    for (const NamedMethod& entry : kKnownMethods) {
        if (name == entry.label) {
            return entry.value;
        }
    }
    throw std::runtime_error("Unknown BitSqueeze method: " + name);
}
class BitSqueezeBufferWrapper {
public:
bitsqueeze_buffer_t* buf = nullptr;
std::vector<int64_t> original_shape;
BitSqueezeBufferWrapper() {}
~BitSqueezeBufferWrapper() {
if (buf) {
bsq_free(buf);
buf = nullptr;
}
}
BitSqueezeBufferWrapper(const BitSqueezeBufferWrapper&) = delete;
BitSqueezeBufferWrapper& operator=(const BitSqueezeBufferWrapper&) = delete;
BitSqueezeBufferWrapper(BitSqueezeBufferWrapper&& other) noexcept : buf(other.buf) { other.buf = nullptr; }
static std::unique_ptr<BitSqueezeBufferWrapper> compress(
torch::Tensor input,
std::string method_name,
float sparse_ratio = 0.0f,
std::optional<torch::Tensor> importance = std::nullopt)
{
TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be Float32");
TORCH_CHECK(input.device().is_cpu(), "Input must be on CPU");
bsq_method_t method = get_method_from_string(method_name);
auto wrapper = std::make_unique<BitSqueezeBufferWrapper>();
wrapper->original_shape = input.sizes().vec();
const float* src_ptr = input.data_ptr<float>();
const float* im_ptr = nullptr;
if (importance.has_value() && importance->defined()) {
torch::Tensor& im = *importance;
TORCH_CHECK(im.is_contiguous(), "Importance tensor must be contiguous");
TORCH_CHECK(im.scalar_type() == torch::kFloat32, "Importance must be Float32");
TORCH_CHECK(im.sizes() == input.sizes(), "Importance shape must match input");
im_ptr = im.data_ptr<float>();
}
int ret = 0;
if (method == TOPK || method == TOPK_IM) {
TORCH_CHECK(input.dim() == 2, "Sparse methods (TOPK, TOPK_IM) require 2D input");
uint16_t num_tokens = (uint16_t)input.size(0);
uint16_t num_features = (uint16_t)input.size(1);
ret = bsq_compress_2d(src_ptr, num_tokens, num_features, sparse_ratio, method, &wrapper->buf, im_ptr);
} else {
uint64_t num_elements = (uint64_t)input.numel();
ret = bsq_compress_1d(src_ptr, num_elements, method, &wrapper->buf, im_ptr);
}
TORCH_CHECK(ret == 0, "BitSqueeze compression failed");
return wrapper;
}
torch::Tensor decompress() {
TORCH_CHECK(buf != nullptr, "Buffer is empty/null");
std::vector<int64_t> shape;
if (buf->method == TOPK || buf->method == TOPK_IM) {
shape.push_back(buf->shape.num_tokens);
shape.push_back(buf->shape.num_features);
} else {
shape.push_back(buf->shape.num_elements);
}
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
torch::Tensor output = torch::empty(shape, options);
int ret = bsq_decompress(buf, output.data_ptr<float>(), (uint64_t)output.numel());
TORCH_CHECK(ret == 0, "BitSqueeze decompression failed");
if (!original_shape.empty()) {
int64_t total_elements = 1;
for(auto d : original_shape) total_elements *= d;
if (total_elements == output.numel()) {
return output.view(original_shape);
}
}
return output;
}
int64_t get_packed_size() {
if (!buf) return 0;
return (int64_t)bsq_get_packed_size(buf);
}
std::string get_method_name() {
if (!buf) return "INVALID";
return std::to_string(buf->method);
}
};
// Python bindings: exposes BitSqueezeBufferWrapper as `BitSqueezeBuffer` with a
// default constructor, a static `compress` factory, a `decompress` method and
// read-only `size` / `method` properties.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    using Wrapper = BitSqueezeBufferWrapper;

    py::class_<Wrapper> buffer_cls(m, "BitSqueezeBuffer");

    buffer_cls.def(py::init<>());

    buffer_cls.def_static(
        "compress",
        &Wrapper::compress,
        py::arg("input"),
        py::arg("method"),
        py::arg("sparse_ratio") = 0.0f,
        py::arg("importance") = py::none(),
        "Compress a float32 tensor into a BitSqueezeBuffer");

    buffer_cls.def("decompress", &Wrapper::decompress);

    // Read-only attributes backed by accessor methods.
    buffer_cls.def_property_readonly("size", &Wrapper::get_packed_size);
    buffer_cls.def_property_readonly("method", &Wrapper::get_method_name);
}