Skip to content
Draft
3 changes: 3 additions & 0 deletions python/quadrants/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from quadrants.types.enums import SNodeGradType
from quadrants.types.ndarray_type import NdarrayType
from quadrants.types.primitive_types import (
PrimitiveBase,
all_types,
f16,
f32,
Expand Down Expand Up @@ -111,6 +112,8 @@ def expr_init(rhs):
return dict((key, expr_init(val)) for key, val in rhs.items())
if isinstance(rhs, _qd_core.DataTypeCxx):
return rhs
if isinstance(rhs, type) and issubclass(rhs, PrimitiveBase):
return rhs.cxx
if isinstance(rhs, _qd_core.Arch):
return rhs
if isinstance(rhs, _Ndrange):
Expand Down
7 changes: 3 additions & 4 deletions python/quadrants/lang/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from quadrants.lang.expr import Expr
from quadrants.lang.impl import axes, get_runtime
from quadrants.profiler.kernel_profiler import get_default_kernel_profiler
from quadrants.types.primitive_types import f32, f64, i32, i64
from quadrants.types.primitive_types import PrimitiveMeta, f32, f64, i32, i64

warnings.filterwarnings("once", category=DeprecationWarning, module="quadrants")

Expand Down Expand Up @@ -314,15 +314,14 @@ def _install_python_backend_dtype_call():
return
_dtype_call_installed = True

DataTypeCxx = type(f32)
_original = DataTypeCxx.__call__
_original = PrimitiveMeta.__call__

def _dtype_call(self, value):
if impl.is_python_backend():
return float(value) if self in _FLOAT_DTYPES else int(value)
return _original(self, value)

DataTypeCxx.__call__ = _dtype_call # type: ignore[assignment]
PrimitiveMeta.__call__ = _dtype_call # type: ignore[assignment]


def init(
Expand Down
107 changes: 78 additions & 29 deletions python/quadrants/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,59 @@
from quadrants.lang import impl
from quadrants.types import Template
from quadrants.types.primitive_types import (
PrimitiveBase,
all_types,
f16,
f16_cxx,
f32,
f32_cxx,
f64,
f64_cxx,
i8,
i8_cxx,
i16,
i16_cxx,
i32,
i32_cxx,
i64,
i64_cxx,
u1,
u1_cxx,
u8,
u8_cxx,
u16,
u16_cxx,
u32,
u32_cxx,
u64,
u64_cxx,
)

MAP_TYPE_IDS = {id(dtype): dtype for dtype in all_types}
MAP_TYPE_IDS: dict[int, Any] = {id(dtype): dtype for dtype in all_types}
_all_cxx_objs = (
f16_cxx,
f32_cxx,
f64_cxx,
i8_cxx,
i16_cxx,
i32_cxx,
i64_cxx,
u1_cxx,
u8_cxx,
u16_cxx,
u32_cxx,
u64_cxx,
)
for _cxx in _all_cxx_objs:
MAP_TYPE_IDS[id(_cxx)] = _cxx

# Pre-computed id-based cache for cook_dtype hot path.
# Maps id(Python class) and id(DataTypeCxx) to the DataTypeCxx result.
_cook_cache: dict[int, _qd_core.DataTypeCxx] = {}
for _cls in (f16, f32, f64, i8, i16, i32, i64, u1, u8, u16, u32, u64):
_cook_cache[id(_cls)] = _cls.cxx
for _cxx in _all_cxx_objs:
_cook_cache[id(_cxx)] = _cxx


def has_pytorch():
Expand Down Expand Up @@ -178,71 +215,74 @@ def to_quadrants_type(dt):
dt (DataType): The desired data type to convert.

Returns:
DataType: The counterpart data type in quadrants.
DataTypeCxx: The counterpart data type in quadrants (always returns DataTypeCxx).

"""
_type = type(dt)
if _type is int:
return MAP_TYPE_IDS[dt]
return cook_dtype(MAP_TYPE_IDS[dt])

if isinstance(dt, type) and issubclass(dt, PrimitiveBase):
return dt.cxx

if issubclass(_type, _qd_core.DataTypeCxx):
return dt

if dt == np.float32:
return f32
return f32.cxx
if dt == np.float64:
return f64
return f64.cxx
if dt == np.int32:
return i32
return i32.cxx
if dt == np.int64:
return i64
return i64.cxx
if dt == np.int8:
return i8
return i8.cxx
if dt == np.int16:
return i16
return i16.cxx
if dt == np.bool_:
return u1
return u1.cxx
if dt == np.uint8:
return u8
return u8.cxx
if dt == np.uint16:
return u16
return u16.cxx
if dt == np.uint32:
return u32
return u32.cxx
if dt == np.uint64:
return u64
return u64.cxx
if dt == np.half:
return f16
return f16.cxx

if has_pytorch():
import torch # pylint: disable=C0415

# pylint: disable=E1101
if dt == torch.float32:
return f32
return f32.cxx
if dt == torch.float64:
return f64
return f64.cxx
if dt == torch.int32:
return i32
return i32.cxx
if dt == torch.int64:
return i64
return i64.cxx
if dt == torch.int8:
return i8
return i8.cxx
if dt == torch.int16:
return i16
return i16.cxx
if dt == torch.bool:
return u1
return u1.cxx
if dt == torch.uint8:
return u8
return u8.cxx
if dt == torch.float16:
return f16
return f16.cxx

if hasattr(torch, "uint16"):
if dt == torch.uint16:
return u16
return u16.cxx
if dt == torch.uint32:
return u32
return u32.cxx
if dt == torch.uint64:
return u64
return u64.cxx

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

Expand All @@ -265,8 +305,17 @@ def __hash__(self):


def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
# Convert Python dtype to CPP dtype
"""Convert Python dtype to C++ DataTypeCxx.

Handles PrimitiveBase classes, raw DataTypeCxx instances, Type instances,
and Python builtins (float, int, bool). Uses id-based cache for hot paths.
"""
cached = _cook_cache.get(id(dtype))
if cached is not None:
return cached
_type = type(dtype)
if isinstance(dtype, type) and issubclass(dtype, PrimitiveBase):
return dtype.cxx
if issubclass(_type, _qd_core.DataTypeCxx):
return dtype
if issubclass(_type, _qd_core.Type):
Expand All @@ -276,7 +325,7 @@ def cook_dtype(dtype: Any) -> _qd_core.DataTypeCxx:
if dtype is int:
return impl.get_runtime().default_ip
if dtype is bool:
return u1
return u1.cxx
raise ValueError(f"Invalid data type {dtype}")


Expand Down
Loading
Loading