From 3414de4ecec47272d87bb01c81fa84d67af32607 Mon Sep 17 00:00:00 2001 From: Alexandru Ormenisan Date: Tue, 20 Aug 2024 10:04:16 +0200 Subject: [PATCH] [HWORKS-1535] Infer model schema from feature view --- python/hsml/model.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/python/hsml/model.py b/python/hsml/model.py index e6147d5f..fb3b1a95 100644 --- a/python/hsml/model.py +++ b/python/hsml/model.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Union import humps -from hsml import client, util +from hsml import ModelSchema, client, util from hsml.constants import ARTIFACT_VERSION from hsml.constants import INFERENCE_ENDPOINTS as IE from hsml.core import explicit_provenance @@ -109,6 +109,20 @@ def __init__( util.ProvenanceWarning, stacklevel=1, ) + if self._model_schema is None: + if self._feature_view is None or self._training_dataset_version is None: + warnings.warn( + "Model schema can only be inferred if both a feature view and training dataset version are known", + util.ProvenanceWarning, + stacklevel=1, + ) + else: + features, labels = self._feature_view.get_training_data( + training_dataset_version=self._training_dataset_version + ) + self._model_schema = ModelSchema( + input_schema=features, output_schema=labels + ) def save( self, @@ -133,6 +147,19 @@ def save( # Returns `Model`: The model metadata object. """ + + if self._model_schema is None: + if ( + self._feature_view is not None + and self._training_dataset_version is not None + ): + features, labels = self._feature_view.get_training_data( + training_dataset_version=self._training_dataset_version + ) + self._model_schema = ModelSchema( + input_schema=features, output_schema=labels + ) + return self._model_engine.save( model_instance=self, model_path=model_path,