Skip to content

Commit 2fe8a27

Browse files
authored
[ML] Add aten::split and aten::stack for question-answering models (#3012)
The deepset/tinyroberta-squad2 model uses aten::split (and aten::stack per ES node logs) in its answer span extraction logic. These ops only appear when traced with AutoModelForQuestionAnswering rather than AutoModel. Update the extraction configs to use the correct auto_class. Also verified that LaBSE, BAAI/bge-reranker-base, and castorini/bpr-nq-ctx-encoder (from the supported models docs) are all covered by the existing allowlist. Made-with: Cursor
1 parent 48a1e66 commit 2fe8a27

4 files changed

Lines changed: 10 additions & 7 deletions

File tree

bin/pytorch_inference/CSupportedOperations.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
124124
"aten::size"sv,
125125
"aten::slice"sv,
126126
"aten::softmax"sv,
127+
"aten::split"sv,
127128
"aten::sqrt"sv,
128129
"aten::squeeze"sv,
130+
"aten::stack"sv,
129131
"aten::str"sv,
130132
"aten::sub"sv,
131133
"aten::sum"sv,

bin/pytorch_inference/unittest/testfiles/reference_model_ops.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,9 +1011,9 @@
10111011
"model_id": "deepset/tinyroberta-squad2",
10121012
"quantized": false,
10131013
"ops": [
1014-
"aten::Int",
10151014
"aten::add",
10161015
"aten::add_",
1016+
"aten::contiguous",
10171017
"aten::cumsum",
10181018
"aten::detach",
10191019
"aten::dropout",
@@ -1027,11 +1027,11 @@
10271027
"aten::ne",
10281028
"aten::reshape",
10291029
"aten::scaled_dot_product_attention",
1030-
"aten::select",
10311030
"aten::size",
10321031
"aten::slice",
1032+
"aten::split",
1033+
"aten::squeeze",
10331034
"aten::sub",
1034-
"aten::tanh",
10351035
"aten::to",
10361036
"aten::transpose",
10371037
"aten::type_as",
@@ -1040,9 +1040,10 @@
10401040
"prim::Constant",
10411041
"prim::GetAttr",
10421042
"prim::ListConstruct",
1043-
"prim::NumToTensor",
1043+
"prim::ListUnpack",
10441044
"prim::TupleConstruct"
1045-
]
1045+
],
1046+
"auto_class": "AutoModelForQuestionAnswering"
10461047
},
10471048
"qa-bart-large-mnli": {
10481049
"model_id": "facebook/bart-large-mnli",

dev-tools/extract_model_ops/reference_models.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},
3232

3333
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
34-
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
34+
"qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"},
3535
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
3636
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}},
3737
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}

dev-tools/extract_model_ops/validation_models.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base",
3333

3434
"_comment:qa-models": "Models from the Appex QA pytorch_tests suite. BART models require auto_class and config_overrides to trace correctly.",
35-
"qa-tinyroberta-squad2": "deepset/tinyroberta-squad2",
35+
"qa-tinyroberta-squad2": {"model_id": "deepset/tinyroberta-squad2", "auto_class": "AutoModelForQuestionAnswering"},
3636
"qa-squeezebert-mnli": "typeform/squeezebert-mnli",
3737
"qa-bart-large-mnli": {"model_id": "facebook/bart-large-mnli", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}},
3838
"qa-distilbart-mnli": {"model_id": "valhalla/distilbart-mnli-12-6", "auto_class": "AutoModelForSequenceClassification", "config_overrides": {"use_cache": false}}

0 commit comments

Comments
 (0)