Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RATapi/examples/normal_reflectivity/background_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np


def backgroundFunction(xdata, params):
def background_function(xdata, params):
# Split up the params array
Ao = params[0]
k = params[1]
Expand Down
85 changes: 37 additions & 48 deletions RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@
import RATapi
import RATapi.controls
import RATapi.wrappers
from RATapi.rat_core import Checks, Control, Limits, NameStore, Priors, ProblemDefinition
from RATapi.rat_core import Checks, Control, Limits, NameStore, ProblemDefinition
from RATapi.utils.enums import Calculations, Languages, LayerModels, TypeOptions

parameter_field = {
"parameters": "params",
"bulk_in": "bulkIns",
"bulk_out": "bulkOuts",
"scalefactors": "scalefactors",
"domain_ratios": "domainRatios",
"background_parameters": "backgroundParams",
"resolution_parameters": "resolutionParams",
}


def get_python_handle(file_name: str, function_name: str, path: Union[str, pathlib.Path] = "") -> Callable:
"""Get the function handle from a function defined in a python module located anywhere within the filesystem.
Expand Down Expand Up @@ -94,7 +104,7 @@ def __len__(self):
return len(self.files)


def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[ProblemDefinition, Limits, Priors, Control]:
def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[ProblemDefinition, Limits, Control]:
"""Constructs the inputs required for the compiled RAT code using the data defined in the input project and
controls.

Expand All @@ -111,65 +121,32 @@ def make_input(project: RATapi.Project, controls: RATapi.Controls) -> tuple[Prob
The problem input used in the compiled RAT code.
limits : RAT.rat_core.Limits
A list of min/max values for each parameter defined in the project.
priors : RAT.rat_core.Priors
The priors defined for each parameter in the project.
cpp_controls : RAT.rat_core.Control
The controls object used in the compiled RAT code.

"""
parameter_field = {
"parameters": "params",
"bulk_in": "bulkIns",
"bulk_out": "bulkOuts",
"scalefactors": "scalefactors",
"domain_ratios": "domainRatios",
"background_parameters": "backgroundParams",
"resolution_parameters": "resolutionParams",
}

prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}

checks = Checks()
limits = Limits()
priors = Priors()

for class_list in RATapi.project.parameter_class_lists:
setattr(checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)])
setattr(
limits,
parameter_field[class_list],
[[element.min, element.max] for element in getattr(project, class_list)],
)
setattr(
priors,
parameter_field[class_list],
[[element.name, element.prior_type, element.mu, element.sigma] for element in getattr(project, class_list)],
)

# Use dummy values for qzshifts
checks.qzshifts = []
# Use dummy value for qzshifts
limits.qzshifts = []
priors.qzshifts = []

priors.priorNames = [
param.name for class_list in RATapi.project.parameter_class_lists for param in getattr(project, class_list)
]
priors.priorValues = [
[prior_id[param.prior_type], param.mu, param.sigma]
for class_list in RATapi.project.parameter_class_lists
for param in getattr(project, class_list)
]

if project.model == LayerModels.CustomXY:
controls.calcSldDuringFit = True

problem = make_problem(project, checks)
problem = make_problem(project)
cpp_controls = make_controls(controls)

return problem, limits, priors, cpp_controls
return problem, limits, cpp_controls


def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
def make_problem(project: RATapi.Project) -> ProblemDefinition:
"""Constructs the problem input required for the compiled RAT code.

Parameters
Expand All @@ -184,6 +161,7 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:

"""
hydrate_id = {"bulk in": 1, "bulk out": 2}
prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}

# Set contrast parameters according to model type
if project.model == LayerModels.StandardLayers:
Expand Down Expand Up @@ -384,21 +362,32 @@ def make_problem(project: RATapi.Project, checks: Checks) -> ProblemDefinition:
for param in getattr(project, class_list)
if not param.fit
]
problem.priorNames = [
param.name for class_list in RATapi.project.parameter_class_lists for param in getattr(project, class_list)
]
problem.priorValues = [
[prior_id[param.prior_type], param.mu, param.sigma]
for class_list in RATapi.project.parameter_class_lists
for param in getattr(project, class_list)
]

# Names
problem.names = NameStore()
problem.names.params = [param.name for param in project.parameters]
problem.names.backgroundParams = [param.name for param in project.background_parameters]
problem.names.scalefactors = [param.name for param in project.scalefactors]
problem.names.qzshifts = [] # Placeholder for qzshifts
problem.names.bulkIns = [param.name for param in project.bulk_in]
problem.names.bulkOuts = [param.name for param in project.bulk_out]
problem.names.resolutionParams = [param.name for param in project.resolution_parameters]
problem.names.domainRatios = [param.name for param in project.domain_ratios]
for class_list in RATapi.project.parameter_class_lists:
setattr(problem.names, parameter_field[class_list], [param.name for param in getattr(project, class_list)])
problem.names.contrasts = [contrast.name for contrast in project.contrasts]

# Checks
problem.checks = checks
problem.checks = Checks()
for class_list in RATapi.project.parameter_class_lists:
setattr(
problem.checks, parameter_field[class_list], [int(element.fit) for element in getattr(project, class_list)]
)

# Use dummy values for qz shifts
problem.names.qzshifts = []
problem.checks.qzshifts = []

check_indices(problem)

return problem
Expand Down
3 changes: 1 addition & 2 deletions RATapi/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def run(project, controls):

horizontal_line = "\u2500" * 107 + "\n"
display_on = controls.display != Display.Off
problem_definition, limits, priors, cpp_controls = make_input(project, controls)
problem_definition, limits, cpp_controls = make_input(project, controls)

if display_on:
print("Starting RAT " + horizontal_line)
Expand All @@ -115,7 +115,6 @@ def run(project, controls):
problem_definition,
limits,
cpp_controls,
priors,
)
end = time.time()

Expand Down
2 changes: 1 addition & 1 deletion cpp/RAT
Submodule RAT updated 115 files
Loading