Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions src/bindings.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ public:
QOCOSettings *get_settings();
PyQOCOSolution &get_solution();

QOCOInt update_settings(const QOCOSettings &);
// QOCOInt update_vector_data(py::object, py::object, py::object);
// QOCOInt update_matrix_data(py::object, py::object, py::object);
QOCOInt update_settings(const QOCOSettings &new_settings);
void update_vector_data(py::object cnew, py::object bnew, py::object hnew);
void update_matrix_data(py::object Pxnew, py::object Axnew, py::object Gxnew);

QOCOInt solve();

Expand Down Expand Up @@ -241,6 +241,78 @@ QOCOInt PyQOCOSolver::update_settings(const QOCOSettings &new_settings)
return qoco_update_settings(this->_solver, &new_settings);
}

void PyQOCOSolver::update_vector_data(py::object cnew, py::object bnew, py::object hnew)
{
QOCOFloat *cnew_ptr = nullptr;
QOCOFloat *bnew_ptr = nullptr;
QOCOFloat *hnew_ptr = nullptr;

if (cnew != py::none())
{
auto cnew_arr = cnew.cast<py::array_t<QOCOFloat>>();
auto buf = cnew_arr.request();
if (buf.shape[0] != this->n)
throw std::runtime_error("cnew size must be n = " + std::to_string(this->n));
cnew_ptr = (QOCOFloat *)buf.ptr;
}

if (bnew != py::none())
{
auto bnew_arr = bnew.cast<py::array_t<QOCOFloat>>();
auto buf = bnew_arr.request();
if (buf.shape[0] != this->p)
throw std::runtime_error("bnew size must be p = " + std::to_string(this->p));
bnew_ptr = (QOCOFloat *)buf.ptr;
}

if (hnew != py::none())
{
auto hnew_arr = hnew.cast<py::array_t<QOCOFloat>>();
auto buf = hnew_arr.request();
if (buf.shape[0] != this->m)
throw std::runtime_error("hnew size must be m = " + std::to_string(this->m));
hnew_ptr = (QOCOFloat *)buf.ptr;
}

qoco_update_vector_data(this->_solver, cnew_ptr, bnew_ptr, hnew_ptr);
}

void PyQOCOSolver::update_matrix_data(py::object Pxnew, py::object Axnew, py::object Gxnew)
{
QOCOFloat *Pxnew_ptr = nullptr;
QOCOFloat *Axnew_ptr = nullptr;
QOCOFloat *Gxnew_ptr = nullptr;

if (Pxnew != py::none())
{
auto Pxnew_arr = Pxnew.cast<py::array_t<QOCOFloat>>();
auto buf = Pxnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Pxnew must be 1-D array");
Pxnew_ptr = (QOCOFloat *)buf.ptr;
}

if (Axnew != py::none())
{
auto Axnew_arr = Axnew.cast<py::array_t<QOCOFloat>>();
auto buf = Axnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Axnew must be 1-D array");
Axnew_ptr = (QOCOFloat *)buf.ptr;
}

if (Gxnew != py::none())
{
auto Gxnew_arr = Gxnew.cast<py::array_t<QOCOFloat>>();
auto buf = Gxnew_arr.request();
if (buf.ndim != 1)
throw std::runtime_error("Gxnew must be 1-D array");
Gxnew_ptr = (QOCOFloat *)buf.ptr;
}

qoco_update_matrix_data(this->_solver, Pxnew_ptr, Axnew_ptr, Gxnew_ptr);
}

PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
{
// Enums.
Expand Down Expand Up @@ -308,6 +380,8 @@ PYBIND11_MODULE(@QOCO_EXT_MODULE_NAME@, m)
.def(py::init<QOCOInt, QOCOInt, QOCOInt, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, const CSC &, const py::array_t<QOCOFloat>, QOCOInt, QOCOInt, const py::array_t<QOCOInt>, QOCOSettings *>(), "n"_a, "m"_a, "p"_a, "P"_a, "c"_a.noconvert(), "A"_a, "b"_a.noconvert(), "G"_a, "h"_a.noconvert(), "l"_a, "nsoc"_a, "q"_a.noconvert(), "settings"_a)
.def_property_readonly("solution", &PyQOCOSolver::get_solution, py::return_value_policy::reference)
.def("update_settings", &PyQOCOSolver::update_settings)
.def("update_vector_data", &PyQOCOSolver::update_vector_data, "cnew"_a=py::none(), "bnew"_a=py::none(), "hnew"_a=py::none())
.def("update_matrix_data", &PyQOCOSolver::update_matrix_data, "Pxnew"_a=py::none(), "Axnew"_a=py::none(), "Gxnew"_a=py::none())
.def("solve", &PyQOCOSolver::solve)
.def("get_settings", &PyQOCOSolver::get_settings, py::return_value_policy::reference);
}
77 changes: 77 additions & 0 deletions src/qoco/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,83 @@ def update_settings(self, **kwargs):
if settings_changed and self._solver is not None:
self._solver.update_settings(self.settings)

def update_vector_data(self, c=None, b=None, h=None):
"""
Update data vectors.

Parameters
----------
c : np.ndarray, optional
New c vector of size n. If None, c is not updated. Default is None.
b : np.ndarray, optional
New b vector of size p. If None, b is not updated. Default is None.
h : np.ndarray, optional
New h vector of size m. If None, h is not updated. Default is None.
"""
if c is not None:
if not isinstance(c, np.ndarray):
c = np.array(c)
c = c.astype(np.float64)
if c.shape[0] != self.n:
raise ValueError(f"c size must be n = {self.n}")

if b is not None:
if not isinstance(b, np.ndarray):
b = np.array(b)
b = b.astype(np.float64)
if b.shape[0] != self.p:
raise ValueError(f"b size must be p = {self.p}")

if h is not None:
if not isinstance(h, np.ndarray):
h = np.array(h)
h = h.astype(np.float64)
if h.shape[0] != self.m:
raise ValueError(f"h size must be m = {self.m}")

return self._solver.update_vector_data(c, b, h)

def update_matrix_data(self, P=None, A=None, G=None):
"""
Update sparse matrix data.

The new matrices must have the same sparsity structure as the original ones.

Parameters
----------
P : np.ndarray, optional
New data for P matrix (only the nonzero values). If None, P is not updated.
Default is None.
A : np.ndarray, optional
New data for A matrix (only the nonzero values). If None, A is not updated.
Default is None.
G : np.ndarray, optional
New data for G matrix (only the nonzero values). If None, G is not updated.
Default is None.
"""
if P is not None:
if not isinstance(P, np.ndarray):
P = np.array(P)
P = P.astype(np.float64)
if P.shape[0] != self.P.nnz:
raise ValueError(f"P size must be {self.P.nnz}")

if A is not None:
if not isinstance(A, np.ndarray):
A = np.array(A)
A = A.astype(np.float64)
if A.shape[0] != self.A.nnz:
raise ValueError(f"A size must be {self.A.nnz}")

if G is not None:
if not isinstance(G, np.ndarray):
G = np.array(G)
G = G.astype(np.float64)
if G.shape[0] != self.G.nnz:
raise ValueError(f"G size must be {self.G.nnz}")

return self._solver.update_matrix_data(P, A, G)

def setup(self, n, m, p, P, c, A, b, G, h, l, nsoc, q, **settings):
self.m = m
self.n = n
Expand Down
Loading