diff --git a/RATapi/inputs.py b/RATapi/inputs.py index e362cf6d..ad1e35b2 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -154,6 +154,22 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: hydrate_id = {"bulk in": 1, "bulk out": 2} prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3} + # Ensure backgrounds and resolutions have a source defined + for contrast in project.contrasts: + background = project.backgrounds[contrast.background] + resolution = project.resolutions[contrast.resolution] + if background.source == "": + raise ValueError( + f"All backgrounds must have a source defined. For a {background.type} type background, " + f"the source must be defined in " + f'"{RATapi.project.values_defined_in[f"backgrounds.{background.type}.source"]}"' + ) + if resolution.source == "" and resolution.type != TypeOptions.Data: + raise ValueError( + f"Constant resolutions must have a source defined. The source must be defined in " + f'"{RATapi.project.values_defined_in[f"resolutions.{resolution.type}.source"]}"' + ) + # Set contrast parameters according to model type if project.model == LayerModels.StandardLayers: if project.calculation == Calculations.Domains: @@ -194,9 +210,9 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: contrast_resolution_params = [] contrast_resolution_types = [] - # set data, background and resolution for each contrast + # Set data, background and resolution for each contrast for contrast in project.contrasts: - # set data + # Set data data_index = project.data.index(contrast.data) data = project.data[data_index].data data_range = project.data[data_index].data_range @@ -212,7 +228,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: else: simulation_limits.append([0.0, 0.0]) - # set background parameters + # Set background parameters background = project.backgrounds[contrast.background] contrast_background_types.append(background.type) contrast_background_param = [] @@ -221,7 +237,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: contrast_background_param.append(project.data.index(background.source, True)) if background.value_1 != "": contrast_background_param.append(project.background_parameters.index(background.value_1)) - # if we are using a data background, we add the background data to the contrast data + # If we are using a data background, we add the background data to the contrast data data = append_data_background(data, project.data[background.source].data) elif background.type == TypeOptions.Function: @@ -245,7 +261,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: contrast_background_params.append(contrast_background_param) - # set resolution parameters + # Set resolution parameters resolution = project.resolutions[contrast.resolution] contrast_resolution_types.append(resolution.type) contrast_resolution_param = [] @@ -270,7 +286,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition: contrast_resolution_params.append(contrast_resolution_param) - # contrast data has exactly six columns to include background data if relevant + # Contrast data has exactly six columns to include background data if relevant all_data.append(np.column_stack((data, np.zeros((data.shape[0], 6 - data.shape[1]))))) problem = ProblemDefinition() diff --git a/RATapi/models.py b/RATapi/models.py index e8e207ae..03b83e36 100644 --- a/RATapi/models.py +++ b/RATapi/models.py @@ -340,13 +340,13 @@ class Data(RATModel, arbitrary_types_allowed=True): """ name: str = Field(default_factory=lambda: f"New Data {next(data_number)}", min_length=1) - data: np.ndarray[np.float64] = np.empty([0, 3]) + data: np.ndarray = np.empty([0, 3]) data_range: list[float] = Field(default=[], min_length=2, max_length=2) simulation_range: list[float] = Field(default=[], min_length=2, max_length=2) @field_validator("data") @classmethod - def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]: + def check_data_dimension(cls, data: np.ndarray) -> np.ndarray: """Ensure the data is be a two-dimensional array containing at least three columns.""" try: data.shape[1]