Skip to content

Commit 52c168a

Browse files
committed
[ML] Add quantized model ops to pytorch_inference allowlist
Add aten::mul_ and quantized::linear_dynamic to the allowed operations list, fixing validation failures for dynamically quantized models such as ELSER v2 when imported via Eland with torch.quantization.quantize_dynamic. Also update the model extraction tooling to support a "quantize" flag in reference_models.json so that quantized variants are traced with dynamic quantization applied before graph extraction, mirroring the Eland import pipeline.

Made-with: Cursor
1 parent c40b317 commit 52c168a

File tree

7 files changed

+135
-20
lines changed

7 files changed

+135
-20
lines changed

bin/pytorch_inference/CSupportedOperations.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
3939
// elastic/test-elser-v2.
4040
// Additional ops from Elasticsearch integration test models
4141
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
42+
// Quantized operations from dynamically quantized variants of the above
43+
// models (torch.quantization.quantize_dynamic on nn.Linear layers).
4244
const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = {
4345
// aten operations (core tensor computations)
4446
"aten::Int"sv,
@@ -79,6 +81,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
7981
"aten::mean"sv,
8082
"aten::min"sv,
8183
"aten::mul"sv,
84+
"aten::mul_"sv,
8285
"aten::ne"sv,
8386
"aten::neg"sv,
8487
"aten::new_ones"sv,
@@ -124,6 +127,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
124127
"prim::dtype"sv,
125128
"prim::max"sv,
126129
"prim::min"sv,
130+
// quantized operations (dynamically quantized models, e.g. ELSER v2)
131+
"quantized::linear_dynamic"sv,
127132
};
128133
}
129134
}

bin/pytorch_inference/unittest/testfiles/reference_model_ops.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
},
268268
"elastic-eis-elser-v2": {
269269
"model_id": "elastic/eis-elser-v2",
270+
"quantized": false,
270271
"ops": [
271272
"aten::Int",
272273
"aten::ScalarImplicit",
@@ -303,6 +304,7 @@
303304
},
304305
"elastic-elser-v2": {
305306
"model_id": "elastic/elser-v2",
307+
"quantized": false,
306308
"ops": [
307309
"aten::Int",
308310
"aten::ScalarImplicit",
@@ -337,6 +339,44 @@
337339
"prim::NumToTensor"
338340
]
339341
},
342+
"elastic-elser-v2-quantized": {
343+
"model_id": "elastic/elser-v2",
344+
"quantized": true,
345+
"ops": [
346+
"aten::Int",
347+
"aten::ScalarImplicit",
348+
"aten::__and__",
349+
"aten::add",
350+
"aten::arange",
351+
"aten::contiguous",
352+
"aten::dropout",
353+
"aten::embedding",
354+
"aten::expand",
355+
"aten::gather",
356+
"aten::ge",
357+
"aten::gelu",
358+
"aten::index",
359+
"aten::layer_norm",
360+
"aten::mul_",
361+
"aten::new_ones",
362+
"aten::reshape",
363+
"aten::scaled_dot_product_attention",
364+
"aten::select",
365+
"aten::size",
366+
"aten::slice",
367+
"aten::tanh",
368+
"aten::to",
369+
"aten::transpose",
370+
"aten::unsqueeze",
371+
"aten::view",
372+
"prim::Constant",
373+
"prim::DictConstruct",
374+
"prim::GetAttr",
375+
"prim::ListConstruct",
376+
"prim::NumToTensor",
377+
"quantized::linear_dynamic"
378+
]
379+
},
340380
"elastic-hugging-face-elser": {
341381
"model_id": "elastic/hugging-face-elser",
342382
"ops": [

dev-tools/extract_model_ops/extract_model_ops.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,40 @@
4141
DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json"
4242

4343

44-
def load_reference_models(config_path: Path) -> dict[str, str]:
45-
"""Load the architecture-to-model mapping from a JSON config file."""
46-
with open(config_path) as f:
47-
return json.load(f)
44+
def load_reference_models(config_path: Path) -> dict[str, dict]:
45+
"""Load the architecture-to-model mapping from a JSON config file.
4846
47+
Each entry is either a plain model name string or a dict with
48+
``model_id`` and optional ``quantize`` flag. All entries are
49+
normalised to ``{"model_id": ..., "quantize": ...}`` dicts.
50+
Keys starting with ``_comment`` are ignored.
51+
"""
52+
with open(config_path) as f:
53+
raw = json.load(f)
4954

50-
def extract_ops_for_model(model_name: str) -> set[str] | None:
55+
models = {}
56+
for key, value in raw.items():
57+
if key.startswith("_comment"):
58+
continue
59+
if isinstance(value, str):
60+
models[key] = {"model_id": value, "quantize": False}
61+
else:
62+
models[key] = {
63+
"model_id": value["model_id"],
64+
"quantize": value.get("quantize", False),
65+
}
66+
return models
67+
68+
69+
def extract_ops_for_model(model_name: str,
70+
quantize: bool = False) -> set[str] | None:
5171
"""Trace a HuggingFace model and return its TorchScript op set.
5272
5373
Returns None if the model could not be loaded or traced.
5474
"""
55-
print(f" Loading {model_name}...", file=sys.stderr)
56-
traced = load_and_trace_hf_model(model_name)
75+
label = f"{model_name} (quantized)" if quantize else model_name
76+
print(f" Loading {label}...", file=sys.stderr)
77+
traced = load_and_trace_hf_model(model_name, quantize=quantize)
5778
if traced is None:
5879
return None
5980
return collect_inlined_ops(traced)
@@ -90,8 +111,9 @@ def main():
90111
file=sys.stderr)
91112

92113
failed = []
93-
for arch, model_name in reference_models.items():
94-
ops = extract_ops_for_model(model_name)
114+
for arch, spec in reference_models.items():
115+
ops = extract_ops_for_model(spec["model_id"],
116+
quantize=spec["quantize"])
95117
if ops is None:
96118
failed.append(arch)
97119
print(f" {arch}: FAILED", file=sys.stderr)
@@ -109,7 +131,8 @@ def main():
109131
"pytorch_version": torch.__version__,
110132
"models": {
111133
arch: {
112-
"model_id": reference_models[arch],
134+
"model_id": reference_models[arch]["model_id"],
135+
"quantized": reference_models[arch]["quantize"],
113136
"ops": sorted(ops),
114137
}
115138
for arch, ops in sorted(per_model_ops.items())
@@ -125,7 +148,11 @@ def main():
125148

126149
if args.per_model:
127150
for arch, ops in sorted(per_model_ops.items()):
128-
print(f"\n=== {arch} ({reference_models[arch]}) ===")
151+
spec = reference_models[arch]
152+
label = spec["model_id"]
153+
if spec["quantize"]:
154+
label += " (quantized)"
155+
print(f"\n=== {arch} ({label}) ===")
129156
for op in sorted(ops):
130157
print(f" {op}")
131158

dev-tools/extract_model_ops/reference_models.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,10 @@
1616
"elastic-hugging-face-elser": "elastic/hugging-face-elser",
1717
"elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized",
1818
"elastic-splade-v3": "elastic/splade-v3",
19-
"elastic-test-elser-v2": "elastic/test-elser-v2"
19+
"elastic-test-elser-v2": "elastic/test-elser-v2",
20+
21+
"_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.",
22+
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantize": true},
23+
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantize": true},
24+
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantize": true}
2025
}

dev-tools/extract_model_ops/torchscript_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@ def collect_inlined_ops(module) -> set[str]:
3535
return collect_graph_ops(graph)
3636

3737

38-
def load_and_trace_hf_model(model_name: str):
38+
def load_and_trace_hf_model(model_name: str, quantize: bool = False):
3939
"""Load a HuggingFace model, tokenize sample input, and trace to TorchScript.
4040
41+
When *quantize* is True the model is dynamically quantized (nn.Linear
42+
layers converted to quantized::linear_dynamic) before tracing. This
43+
mirrors what Eland does when importing models for Elasticsearch.
44+
4145
Returns the traced module, or None if the model could not be loaded or traced.
4246
"""
4347
token = os.environ.get("HF_TOKEN")
@@ -53,6 +57,16 @@ def load_and_trace_hf_model(model_name: str):
5357
print(f" LOAD ERROR: {exc}", file=sys.stderr)
5458
return None
5559

60+
if quantize:
61+
try:
62+
model = torch.quantization.quantize_dynamic(
63+
model, {torch.nn.Linear}, dtype=torch.qint8)
64+
print(" Applied dynamic quantization (nn.Linear -> qint8)",
65+
file=sys.stderr)
66+
except Exception as exc:
67+
print(f" QUANTIZE ERROR: {exc}", file=sys.stderr)
68+
return None
69+
5670
inputs = tokenizer(
5771
"This is a sample input for graph extraction.",
5872
return_tensors="pt", padding="max_length",

dev-tools/extract_model_ops/validate_allowlist.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def check_ops(ops: set[str],
103103
def validate_model(model_name: str,
104104
allowed: set[str],
105105
forbidden: set[str],
106-
verbose: bool) -> bool:
106+
verbose: bool,
107+
quantize: bool = False) -> bool:
107108
"""Validate one HuggingFace model. Returns True if all ops pass."""
108-
print(f" {model_name}...", file=sys.stderr)
109-
traced = load_and_trace_hf_model(model_name)
109+
label = f"{model_name} (quantized)" if quantize else model_name
110+
print(f" {label}...", file=sys.stderr)
111+
traced = load_and_trace_hf_model(model_name, quantize=quantize)
110112
if traced is None:
111113
print(f" FAILED (could not load/trace)", file=sys.stderr)
112114
return False
@@ -152,13 +154,27 @@ def main():
152154
results: dict[str, bool] = {}
153155

154156
with open(args.config) as f:
155-
models = json.load(f)
157+
raw_models = json.load(f)
158+
159+
models = {}
160+
for key, value in raw_models.items():
161+
if key.startswith("_comment"):
162+
continue
163+
if isinstance(value, str):
164+
models[key] = {"model_id": value, "quantize": False}
165+
else:
166+
models[key] = {
167+
"model_id": value["model_id"],
168+
"quantize": value.get("quantize", False),
169+
}
170+
156171
print(f"Validating {len(models)} HuggingFace models from "
157172
f"{args.config.name}...", file=sys.stderr)
158173

159-
for arch, model_id in models.items():
174+
for arch, spec in models.items():
160175
results[arch] = validate_model(
161-
model_id, allowed, forbidden, args.verbose)
176+
spec["model_id"], allowed, forbidden, args.verbose,
177+
quantize=spec["quantize"])
162178

163179
if args.pt_dir and args.pt_dir.is_dir():
164180
pt_files = sorted(args.pt_dir.glob("*.pt"))
@@ -178,7 +194,11 @@ def main():
178194
if key.startswith("pt:"):
179195
print(f" {key}: {status}", file=sys.stderr)
180196
else:
181-
print(f" {key} ({models[key]}): {status}", file=sys.stderr)
197+
spec = models[key]
198+
label = spec["model_id"]
199+
if spec["quantize"]:
200+
label += " (quantized)"
201+
print(f" {key} ({label}): {status}", file=sys.stderr)
182202

183203
print("=" * 60, file=sys.stderr)
184204
if all_pass:

dev-tools/extract_model_ops/validation_models.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"elastic-splade-v3": "elastic/splade-v3",
2020
"elastic-test-elser-v2": "elastic/test-elser-v2",
2121

22+
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantize": true},
23+
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantize": true},
24+
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantize": true},
25+
2226
"ner-dslim-bert-base": "dslim/bert-base-NER",
2327
"sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english",
2428

0 commit comments

Comments (0)