@@ -104,30 +104,41 @@ def validate_model(model_name: str,
104104 allowed : set [str ],
105105 forbidden : set [str ],
106106 verbose : bool ,
107- quantize : bool = False ) -> bool :
108- """Validate one HuggingFace model. Returns True if all ops pass."""
107+ quantize : bool = False ,
108+ auto_class : str | None = None ,
109+ config_overrides : dict | None = None ) -> str :
110+ """Validate one HuggingFace model.
111+
112+ Returns "pass", "fail" (op validation failed), or "skip" (could not
113+ load/trace — e.g. private model without HF_TOKEN).
114+ """
109115 label = f"{ model_name } (quantized)" if quantize else model_name
110116 print (f" { label } ..." , file = sys .stderr )
111- traced = load_and_trace_hf_model (model_name , quantize = quantize )
117+ traced = load_and_trace_hf_model (model_name , quantize = quantize ,
118+ auto_class = auto_class ,
119+ config_overrides = config_overrides )
112120 if traced is None :
113- print (f" FAILED (could not load/trace)" , file = sys .stderr )
114- return False
121+ print (f" SKIPPED (could not load/trace)" , file = sys .stderr )
122+ return "skip"
115123 ops = collect_inlined_ops (traced )
116- return check_ops (ops , allowed , forbidden , verbose )
124+ return "pass" if check_ops (ops , allowed , forbidden , verbose ) else "fail"
117125
118126
119127def validate_pt_file (name : str ,
120128 pt_path : str ,
121129 allowed : set [str ],
122130 forbidden : set [str ],
123- verbose : bool ) -> bool :
124- """Validate a local TorchScript .pt file. Returns True if all ops pass."""
131+ verbose : bool ) -> str :
132+ """Validate a local TorchScript .pt file.
133+
134+ Returns "pass", "fail", or "skip".
135+ """
125136 print (f" { name } ({ pt_path } )..." , file = sys .stderr )
126137 ops = load_pt_and_collect_ops (pt_path )
127138 if ops is None :
128- print (f" FAILED (could not load)" , file = sys .stderr )
129- return False
130- return check_ops (ops , allowed , forbidden , verbose )
139+ print (f" SKIPPED (could not load)" , file = sys .stderr )
140+ return "skip"
141+ return "pass" if check_ops (ops , allowed , forbidden , verbose ) else "fail"
131142
132143
133144def main ():
@@ -151,7 +162,7 @@ def main():
151162 print (f"Parsed { len (allowed )} allowed ops and { len (forbidden )} "
152163 f"forbidden ops from { SUPPORTED_OPS_CC .name } " , file = sys .stderr )
153164
154- results : dict [str , bool ] = {}
165+ results : dict [str , str ] = {}
155166
156167 models = load_model_config (args .config )
157168
@@ -161,7 +172,9 @@ def main():
161172 for arch , spec in models .items ():
162173 results [arch ] = validate_model (
163174 spec ["model_id" ], allowed , forbidden , args .verbose ,
164- quantize = spec ["quantized" ])
175+ quantize = spec ["quantized" ],
176+ auto_class = spec .get ("auto_class" ),
177+ config_overrides = spec .get ("config_overrides" ))
165178
166179 if args .pt_dir and args .pt_dir .is_dir ():
167180 pt_files = sorted (args .pt_dir .glob ("*.pt" ))
@@ -175,26 +188,32 @@ def main():
175188
176189 print (file = sys .stderr )
177190 print ("=" * 60 , file = sys .stderr )
178- all_pass = all (results .values ())
179- for key , passed in results .items ():
180- status = "PASS" if passed else "FAIL"
191+ for key , status in results .items ():
192+ display = status .upper ()
181193 if key .startswith ("pt:" ):
182- print (f" { key } : { status } " , file = sys .stderr )
194+ print (f" { key } : { display } " , file = sys .stderr )
183195 else :
184196 spec = models [key ]
185197 label = spec ["model_id" ]
186198 if spec ["quantized" ]:
187199 label += " (quantized)"
188- print (f" { key } ({ label } ): { status } " , file = sys .stderr )
200+ print (f" { key } ({ label } ): { display } " , file = sys .stderr )
201+
202+ failed = [a for a , s in results .items () if s == "fail" ]
203+ skipped = [a for a , s in results .items () if s == "skip" ]
204+ passed = [a for a , s in results .items () if s == "pass" ]
189205
190206 print ("=" * 60 , file = sys .stderr )
191- if all_pass :
192- print ("All models PASS - no false positives." , file = sys .stderr )
193- else :
194- failed = [a for a , p in results .items () if not p ]
195- print (f"FAILED models: { ', ' .join (failed )} " , file = sys .stderr )
207+ print (f"{ len (passed )} passed, { len (failed )} failed, "
208+ f"{ len (skipped )} skipped" , file = sys .stderr )
209+
210+ if skipped :
211+ print (f"Skipped (could not load/trace — may need HF_TOKEN "
212+ f"for private models): { ', ' .join (skipped )} " , file = sys .stderr )
213+ if failed :
214+ print (f"FAILED (op validation): { ', ' .join (failed )} " , file = sys .stderr )
196215
197- sys .exit (0 if all_pass else 1 )
216+ sys .exit (0 if not failed else 1 )
198217
199218
200219if __name__ == "__main__" :
0 commit comments