MueLu  Version of the Day
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MueLu_MatlabUtils.cpp
Go to the documentation of this file.
1 // @HEADER
2 // *****************************************************************************
3 // MueLu: A package for multigrid based preconditioning
4 //
5 // Copyright 2012 NTESS and the MueLu contributors.
6 // SPDX-License-Identifier: BSD-3-Clause
7 // *****************************************************************************
8 // @HEADER
9 
11 
12 #if !defined(HAVE_MUELU_MATLAB) || !defined(HAVE_MUELU_TPETRA)
13 #error "Muemex types require MATLAB and Tpetra."
14 #else
15 
/* MATLAB releases older than R2006b (MX API version 7.3) do not provide the
 * mwIndex type, so fall back to plain int for them. */
#if !defined(MX_API_VER) || MX_API_VER < 0x07030000
using mwIndex = int;
#endif
21 
22 using namespace std;
23 using namespace Teuchos;
24 
25 namespace MueLu {
26 
27 /* Explicit instantiation of MuemexData variants */
28 template class MuemexData<RCP<Xpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
29 template class MuemexData<RCP<Xpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
30 template class MuemexData<RCP<Xpetra::Matrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
31 template class MuemexData<RCP<Xpetra::Matrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
32 template class MuemexData<RCP<MAggregates> >;
33 template class MuemexData<RCP<MAmalInfo> >;
34 template class MuemexData<int>;
35 template class MuemexData<bool>;
36 template class MuemexData<complex_t>;
37 template class MuemexData<string>;
38 template class MuemexData<double>;
39 template class MuemexData<RCP<Tpetra::CrsMatrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
40 template class MuemexData<RCP<Tpetra::CrsMatrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
41 #ifdef HAVE_MUELU_EPETRA
42 template class MuemexData<RCP<Epetra_MultiVector> >;
43 #endif
44 template class MuemexData<RCP<Tpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
45 template class MuemexData<RCP<Tpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
46 template class MuemexData<RCP<Xpetra::Vector<mm_LocalOrd, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >;
47 
48 // Flag set to true if MATLAB's CSC matrix index type is not int (usually false)
49 bool rewrap_ints = sizeof(int) != sizeof(mwIndex);
50 
51 int* mwIndex_to_int(int N, mwIndex* mwi_array) {
52  // int* rv = (int*) malloc(N * sizeof(int));
53  int* rv = new int[N]; // not really better but may avoid confusion for valgrind
54  for (int i = 0; i < N; i++)
55  rv[i] = (int)mwi_array[i];
56  return rv;
57 }
58 
59 /* ******************************* */
60 /* Specializations */
61 /* ******************************* */
62 
63 template <>
64 mxArray* createMatlabSparse<double>(int numRows, int numCols, int nnz) {
65  return mxCreateSparse(numRows, numCols, nnz, mxREAL);
66 }
67 
68 template <>
69 mxArray* createMatlabSparse<complex_t>(int numRows, int numCols, int nnz) {
70  return mxCreateSparse(numRows, numCols, nnz, mxCOMPLEX);
71 }
72 
73 template <>
74 void fillMatlabArray<double>(double* array, const mxArray* mxa, int n) {
75  memcpy(mxGetPr(mxa), array, n * sizeof(double));
76 }
77 
78 template <>
79 void fillMatlabArray<complex_t>(complex_t* array, const mxArray* mxa, int n) {
80  double* pr = mxGetPr(mxa);
81  double* pi = mxGetPi(mxa);
82  for (int i = 0; i < n; i++) {
83  pr[i] = std::real<double>(array[i]);
84  pi[i] = std::imag<double>(array[i]);
85  }
86 }
87 
88 /******************************/
89 /* Callback Functions */
90 /******************************/
91 
92 void callMatlabNoArgs(std::string function) {
93  int result = mexEvalString(function.c_str());
94  if (result != 0)
95  mexPrintf("An error occurred while running a MATLAB command.");
96 }
97 
98 std::vector<RCP<MuemexArg> > callMatlab(std::string function, int numOutputs, std::vector<RCP<MuemexArg> > args) {
99  using Teuchos::rcp_static_cast;
100  mxArray** matlabArgs = new mxArray*[args.size()];
101  mxArray** matlabOutput = new mxArray*[numOutputs];
102  std::vector<RCP<MuemexArg> > output;
103 
104  for (int i = 0; i < int(args.size()); i++) {
105  try {
106  switch (args[i]->type) {
107  case BOOL:
108  matlabArgs[i] = rcp_static_cast<MuemexData<bool>, MuemexArg>(args[i])->convertToMatlab();
109  break;
110  case INT:
111  matlabArgs[i] = rcp_static_cast<MuemexData<int>, MuemexArg>(args[i])->convertToMatlab();
112  break;
113  case DOUBLE:
114  matlabArgs[i] = rcp_static_cast<MuemexData<double>, MuemexArg>(args[i])->convertToMatlab();
115  break;
116  case STRING:
117  matlabArgs[i] = rcp_static_cast<MuemexData<std::string>, MuemexArg>(args[i])->convertToMatlab();
118  break;
119  case COMPLEX:
120  matlabArgs[i] = rcp_static_cast<MuemexData<complex_t>, MuemexArg>(args[i])->convertToMatlab();
121  break;
122  case XPETRA_MAP:
123  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_map> >, MuemexArg>(args[i])->convertToMatlab();
124  break;
126  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_ordinal_vector> >, MuemexArg>(args[i])->convertToMatlab();
127  break;
129  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
130  break;
132  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
133  break;
135  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::CrsMatrix<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
136  break;
138  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Tpetra::CrsMatrix<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >, MuemexArg>(args[i])->convertToMatlab();
139  break;
141  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_Matrix_double> >, MuemexArg>(args[i])->convertToMatlab();
142  break;
144  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_Matrix_complex> >, MuemexArg>(args[i])->convertToMatlab();
145  break;
147  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_MultiVector_double> >, MuemexArg>(args[i])->convertToMatlab();
148  break;
150  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Xpetra_MultiVector_complex> >, MuemexArg>(args[i])->convertToMatlab();
151  break;
152 #ifdef HAVE_MUELU_EPETRA
153  case EPETRA_CRSMATRIX:
154  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Epetra_CrsMatrix> >, MuemexArg>(args[i])->convertToMatlab();
155  break;
156  case EPETRA_MULTIVECTOR:
157  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<Epetra_MultiVector> >, MuemexArg>(args[i])->convertToMatlab();
158  break;
159 #endif
160  case AGGREGATES:
161  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MAggregates> >, MuemexArg>(args[i])->convertToMatlab();
162  break;
163  case AMALGAMATION_INFO:
164  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MAmalInfo> >, MuemexArg>(args[i])->convertToMatlab();
165  break;
166  case GRAPH:
167  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<MGraph> >, MuemexArg>(args[i])->convertToMatlab();
168 #ifdef HAVE_MUELU_INTREPID2
169  case FIELDCONTAINER_ORDINAL:
170  matlabArgs[i] = rcp_static_cast<MuemexData<RCP<FieldContainer_ordinal> >, MuemexArg>(args[i])->convertToMatlab();
171  break;
172 #endif
173  }
174  } catch (std::exception& e) {
175  mexPrintf("An error occurred while converting arg #%d to MATLAB:\n", i);
176  std::cout << e.what() << std::endl;
177  mexPrintf("Passing 0 instead.\n");
178  matlabArgs[i] = mxCreateDoubleScalar(0);
179  }
180  }
181  // now matlabArgs is populated with MATLAB data types
182  int result = mexCallMATLAB(numOutputs, matlabOutput, args.size(), matlabArgs, function.c_str());
183  if (result != 0)
184  mexPrintf("Matlab encountered an error while running command through muemexCallbacks.\n");
185  // now, if all went well, matlabOutput contains all the output to return to user
186  for (int i = 0; i < numOutputs; i++) {
187  try {
188  output.push_back(convertMatlabVar(matlabOutput[i]));
189  } catch (std::exception& e) {
190  mexPrintf("An error occurred while converting output #%d from MATLAB:\n", i);
191  std::cout << e.what() << std::endl;
192  }
193  }
194  delete[] matlabOutput;
195  delete[] matlabArgs;
196  return output;
197 }
198 
199 /******************************/
200 /* More utility functions */
201 /******************************/
202 
203 template <>
204 mxArray* createMatlabMultiVector<double>(int numRows, int numCols) {
205  return mxCreateDoubleMatrix(numRows, numCols, mxREAL);
206 }
207 
208 template <>
209 mxArray* createMatlabMultiVector<complex_t>(int numRows, int numCols) {
210  return mxCreateDoubleMatrix(numRows, numCols, mxCOMPLEX);
211 }
212 
214  throw runtime_error("AmalgamationInfo not supported in MueMex yet.");
215  return mxCreateDoubleScalar(0);
216 }
217 
219  bool isValidAggregates = true;
220  if (!mxIsStruct(mxa))
221  return false;
222  int numFields = mxGetNumberOfFields(mxa); // check that struct has correct # of fields
223  if (numFields != 5)
224  isValidAggregates = false;
225  if (isValidAggregates) {
226  const char* mem1 = mxGetFieldNameByNumber(mxa, 0);
227  if (mem1 == NULL || strcmp(mem1, "nVertices") != 0)
228  isValidAggregates = false;
229  const char* mem2 = mxGetFieldNameByNumber(mxa, 1);
230  if (mem2 == NULL || strcmp(mem2, "nAggregates") != 0)
231  isValidAggregates = false;
232  const char* mem3 = mxGetFieldNameByNumber(mxa, 2);
233  if (mem3 == NULL || strcmp(mem3, "vertexToAggID") != 0)
234  isValidAggregates = false;
235  const char* mem4 = mxGetFieldNameByNumber(mxa, 3);
236  if (mem3 == NULL || strcmp(mem4, "rootNodes") != 0)
237  isValidAggregates = false;
238  const char* mem5 = mxGetFieldNameByNumber(mxa, 4);
239  if (mem4 == NULL || strcmp(mem5, "aggSizes") != 0)
240  isValidAggregates = false;
241  }
242  return isValidAggregates;
243 }
244 
245 bool isValidMatlabGraph(const mxArray* mxa) {
246  bool isValidGraph = true;
247  if (!mxIsStruct(mxa))
248  return false;
249  int numFields = mxGetNumberOfFields(mxa); // check that struct has correct # of fields
250  if (numFields != 2)
251  isValidGraph = false;
252  if (isValidGraph) {
253  const char* mem1 = mxGetFieldNameByNumber(mxa, 0);
254  if (mem1 == NULL || strcmp(mem1, "edges") != 0)
255  isValidGraph = false;
256  const char* mem2 = mxGetFieldNameByNumber(mxa, 1);
257  if (mem2 == NULL || strcmp(mem2, "boundaryNodes") != 0)
258  isValidGraph = false;
259  }
260  return isValidGraph;
261 }
262 
/// Split a comma-separated list into tokens.
/// - Empty segments are skipped: consecutive, leading and trailing commas
///   produce no token.
/// - Each non-empty segment has leading and trailing space characters (' ')
///   removed, so a segment of only spaces yields an empty string.
/// - Other whitespace such as tabs is kept.
/// - Input after an embedded NUL character is ignored.
/// @param params The list text, e.g. "A, P ,R".
/// @return The trimmed tokens in order of appearance.
std::vector<std::string> tokenizeList(const std::string& params) {
  // Re-read through c_str() so an embedded NUL ends the input, as with C strings.
  const std::string text(params.c_str());
  std::vector<std::string> tokens;
  std::string::size_type start = 0;
  while (start <= text.size()) {
    std::string::size_type comma = text.find(',', start);
    if (comma == std::string::npos)
      comma = text.size();
    if (comma > start) {
      const std::string segment = text.substr(start, comma - start);
      const std::string::size_type first = segment.find_first_not_of(' ');
      if (first == std::string::npos) {
        tokens.push_back(std::string());
      } else {
        const std::string::size_type last = segment.find_last_not_of(' ');
        tokens.push_back(segment.substr(first, last - first + 1));
      }
    }
    start = comma + 1;
  }
  return tokens;
}
286 
288  using namespace Teuchos;
289  RCP<ParameterList> validParamList = rcp(new ParameterList());
290  validParamList->set<RCP<const FactoryBase> >("A", Teuchos::null, "Factory for the matrix A.");
291  validParamList->set<RCP<const FactoryBase> >("P", Teuchos::null, "Factory for the prolongator.");
292  validParamList->set<RCP<const FactoryBase> >("R", Teuchos::null, "Factory for the restrictor.");
293  validParamList->set<RCP<const FactoryBase> >("Ptent", Teuchos::null, "Factory for the tentative (unsmoothed) prolongator.");
294  validParamList->set<RCP<const FactoryBase> >("Coordinates", Teuchos::null, "Factory for the node coordinates.");
295  validParamList->set<RCP<const FactoryBase> >("Nullspace", Teuchos::null, "Factory for the nullspace.");
296  validParamList->set<RCP<const FactoryBase> >("Aggregates", Teuchos::null, "Factory for the aggregates.");
297  validParamList->set<RCP<const FactoryBase> >("UnamalgamationInfo", Teuchos::null, "Factory for amalgamation.");
298 #ifdef HAVE_MUELU_INTREPID2
299  validParamList->set<RCP<const FactoryBase> >("pcoarsen: element to node map", Teuchos::null, "Generating factory of the element to node map");
300 #endif
301  return validParamList;
302 }
303 
305  switch (mxGetClassID(mxa)) {
306  case mxCHAR_CLASS:
307  // string
308  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<std::string>(mxa)));
309  break;
310  case mxLOGICAL_CLASS:
311  // boolean
312  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<bool>(mxa)));
313  break;
314  case mxINT32_CLASS:
315  if (mxGetM(mxa) == 1 && mxGetN(mxa) == 1)
316  // individual integer
317  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<int>(mxa)));
318  else if (mxGetM(mxa) != 1 || mxGetN(mxa) != 1)
319  // ordinal vector
320  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_ordinal_vector> >(mxa)));
321  else
322  throw std::runtime_error("Error: Don't know what to do with integer array.\n");
323  break;
324  case mxDOUBLE_CLASS:
325  if (mxGetM(mxa) == 1 && mxGetN(mxa) == 1) {
326  if (mxIsComplex(mxa))
327  // single double (scalar, real)
328  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<complex_t>(mxa)));
329  else
330  // single complex scalar
331  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<double>(mxa)));
332  } else if (mxIsSparse(mxa)) // use a CRS matrix
333  {
334  // Default to Tpetra matrix for this
335  if (mxIsComplex(mxa))
336  // complex matrix
337  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_Matrix_complex> >(mxa)));
338  else
339  // real-valued matrix
340  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra_Matrix_double> >(mxa)));
341  } else {
342  // Default to Xpetra multivector for this case
343  if (mxIsComplex(mxa))
344  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra::MultiVector<complex_t, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >(mxa)));
345  else
346  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<Xpetra::MultiVector<double, mm_LocalOrd, mm_GlobalOrd, mm_node_t> > >(mxa)));
347  }
348  break;
349  case mxSTRUCT_CLASS: {
350  // the only thing that should get here currently is an Aggregates struct or Graph struct
351  // verify that it has the correct fields with the correct types
352  // also assume that aggregates data will not be stored in an array of more than 1 element.
353  if (isValidMatlabAggregates(mxa)) {
354  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<MAggregates> >(mxa)));
355  } else if (isValidMatlabGraph(mxa)) {
356  return rcp_implicit_cast<MuemexArg>(rcp(new MuemexData<RCP<MGraph> >(mxa)));
357  } else {
358  throw runtime_error("Invalid aggregates or graph struct passed in from MATLAB.");
359  return Teuchos::null;
360  }
361  break;
362  }
363  default:
364  throw std::runtime_error("MATLAB returned an unsupported type as a function output.\n");
365  return Teuchos::null;
366  }
367 }
368 
369 /******************************/
370 /* Explicit Instantiations */
371 /******************************/
372 
373 template bool loadDataFromMatlab<bool>(const mxArray* mxa);
374 template int loadDataFromMatlab<int>(const mxArray* mxa);
375 template double loadDataFromMatlab<double>(const mxArray* mxa);
376 template complex_t loadDataFromMatlab<complex_t>(const mxArray* mxa);
377 template string loadDataFromMatlab<string>(const mxArray* mxa);
378 template RCP<Xpetra_ordinal_vector> loadDataFromMatlab<RCP<Xpetra_ordinal_vector> >(const mxArray* mxa);
379 template RCP<Tpetra_MultiVector_double> loadDataFromMatlab<RCP<Tpetra_MultiVector_double> >(const mxArray* mxa);
380 template RCP<Tpetra_MultiVector_complex> loadDataFromMatlab<RCP<Tpetra_MultiVector_complex> >(const mxArray* mxa);
381 template RCP<Tpetra_CrsMatrix_double> loadDataFromMatlab<RCP<Tpetra_CrsMatrix_double> >(const mxArray* mxa);
382 template RCP<Tpetra_CrsMatrix_complex> loadDataFromMatlab<RCP<Tpetra_CrsMatrix_complex> >(const mxArray* mxa);
383 template RCP<Xpetra_Matrix_double> loadDataFromMatlab<RCP<Xpetra_Matrix_double> >(const mxArray* mxa);
384 template RCP<Xpetra_Matrix_complex> loadDataFromMatlab<RCP<Xpetra_Matrix_complex> >(const mxArray* mxa);
385 template RCP<Xpetra_MultiVector_double> loadDataFromMatlab<RCP<Xpetra_MultiVector_double> >(const mxArray* mxa);
386 template RCP<Xpetra_MultiVector_complex> loadDataFromMatlab<RCP<Xpetra_MultiVector_complex> >(const mxArray* mxa);
387 #ifdef HAVE_MUELU_EPETRA
388 template RCP<Epetra_CrsMatrix> loadDataFromMatlab<RCP<Epetra_CrsMatrix> >(const mxArray* mxa);
389 template RCP<Epetra_MultiVector> loadDataFromMatlab<RCP<Epetra_MultiVector> >(const mxArray* mxa);
390 #endif
391 template RCP<MAggregates> loadDataFromMatlab<RCP<MAggregates> >(const mxArray* mxa);
392 template RCP<MAmalInfo> loadDataFromMatlab<RCP<MAmalInfo> >(const mxArray* mxa);
393 
394 template mxArray* saveDataToMatlab(bool& data);
395 template mxArray* saveDataToMatlab(int& data);
396 template mxArray* saveDataToMatlab(double& data);
397 template mxArray* saveDataToMatlab(complex_t& data);
398 template mxArray* saveDataToMatlab(string& data);
408 #ifdef HAVE_MUELU_EPETRA
411 #endif
412 template mxArray* saveDataToMatlab(RCP<MAggregates>& data);
413 template mxArray* saveDataToMatlab(RCP<MAmalInfo>& data);
414 
415 template vector<RCP<MuemexArg> > processNeeds<double>(const Factory* factory, string& needsParam, Level& lvl);
416 template vector<RCP<MuemexArg> > processNeeds<complex_t>(const Factory* factory, string& needsParam, Level& lvl);
417 template void processProvides<double>(vector<RCP<MuemexArg> >& mexOutput, const Factory* factory, string& providesParam, Level& lvl);
418 template void processProvides<complex_t>(vector<RCP<MuemexArg> >& mexOutput, const Factory* factory, string& providesParam, Level& lvl);
419 
420 } // namespace MueLu
421 #endif // HAVE_MUELU_MATLAB
bool isValidMatlabAggregates(const mxArray *mxa)
template mxArray * saveDataToMatlab(bool &data)
std::vector< std::string > tokenizeList(const std::string &params)
mxArray * createMatlabSparse< complex_t >(int numRows, int numCols, int nnz)
bool rewrap_ints
template vector< RCP< MuemexArg > > processNeeds< complex_t >(const Factory *factory, string &needsParam, Level &lvl)
template void processProvides< complex_t >(vector< RCP< MuemexArg > > &mexOutput, const Factory *factory, string &providesParam, Level &lvl)
mxArray * saveAmalInfo(RCP< MAmalInfo > &amalInfo)
std::vector< RCP< MuemexArg > > callMatlab(std::string function, int numOutputs, std::vector< RCP< MuemexArg > > args)
template vector< RCP< MuemexArg > > processNeeds< double >(const Factory *factory, string &needsParam, Level &lvl)
void fillMatlabArray< double >(double *array, const mxArray *mxa, int n)
bool isValidMatlabGraph(const mxArray *mxa)
template string loadDataFromMatlab< string >(const mxArray *mxa)
Teuchos::RCP< Teuchos::ParameterList > getInputParamList()
template void processProvides< double >(vector< RCP< MuemexArg > > &mexOutput, const Factory *factory, string &providesParam, Level &lvl)
struct mxArray_tag mxArray
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
void fillMatlabArray< complex_t >(complex_t *array, const mxArray *mxa, int n)
mxArray * createMatlabSparse< double >(int numRows, int numCols, int nnz)
int * mwIndex_to_int(int N, mwIndex *mwi_array)
template complex_t loadDataFromMatlab< complex_t >(const mxArray *mxa)
template int loadDataFromMatlab< int >(const mxArray *mxa)
template bool loadDataFromMatlab< bool >(const mxArray *mxa)
mxArray * createMatlabMultiVector< complex_t >(int numRows, int numCols)
void callMatlabNoArgs(std::string function)
std::complex< double > complex_t
mxArray * createMatlabMultiVector< double >(int numRows, int numCols)
template double loadDataFromMatlab< double >(const mxArray *mxa)
int mwIndex
Teuchos::RCP< MuemexArg > convertMatlabVar(const mxArray *mxa)