diff --git a/palace/fem/libceed/restriction.cpp b/palace/fem/libceed/restriction.cpp index 12d5d655f..61ffb4391 100644 --- a/palace/fem/libceed/restriction.cpp +++ b/palace/fem/libceed/restriction.cpp @@ -102,9 +102,8 @@ void InitNativeRestr(const mfem::FiniteElementSpace &fespace, if (!fe) { int elem_id, face_info; - int FaceNo = fespace.GetMesh()->GetBdrElementFaceIndex(indices[0]); - fespace.GetMesh()->GetBdrElementAdjacentElement(FaceNo, elem_id, face_info); - mfem::Geometry::Type face_geom = fespace.GetMesh()->GetBdrElementGeometry(FaceNo); + fespace.GetMesh()->GetBdrElementAdjacentElement(indices[0], elem_id, face_info); + mfem::Geometry::Type face_geom = fespace.GetMesh()->GetBdrElementGeometry(indices[0]); fe = fespace.GetTraceElement(elem_id, face_geom); face_flg = true; } @@ -163,8 +162,55 @@ void InitNativeRestr(const mfem::FiniteElementSpace &fespace, } else { - int f = fespace.GetMesh()->GetBdrElementFaceIndex(e); - fespace.GetFaceVDofs(f, dofs); + // Get coordinates of face dofs + int elem_id, face_info; + fespace.GetMesh()->GetBdrElementAdjacentElement(e, elem_id, face_info); + mfem::Geometry::Type face_geom = fespace.GetMesh()->GetBdrElementGeometry(e); + face_info = fespace.GetMesh()->EncodeFaceInfo( + fespace.GetMesh()->DecodeFaceInfoLocalIndex(face_info), + mfem::Geometry::GetInverseOrientation( + face_geom, fespace.GetMesh()->DecodeFaceInfoOrientation(face_info)) + ); + mfem::IntegrationPointTransformation Loc1; + fespace.GetMesh()->GetLocalFaceTransformation(fespace.GetMesh()->GetBdrElementType(e), + fespace.GetMesh()->GetElementType(elem_id), + Loc1.Transf, face_info); + const mfem::FiniteElement *face_el = fespace.GetTraceElement(elem_id, face_geom); + MFEM_VERIFY(dynamic_cast(face_el), + "Mesh requires nodal Finite Element."); + mfem::IntegrationRule face_ir(face_el->GetDof()); + Loc1.Transf.ElementNo = elem_id; + Loc1.Transf.mesh = fespace.GetMesh(); + Loc1.Transf.ElementType = mfem::ElementTransformation::ELEMENT; + Loc1.Transform(face_el->GetNodes(), face_ir); + mfem::IsoparametricTransformation face_tr; + face_tr.ElementNo = e; + face_tr.ElementType = mfem::ElementTransformation::BDR_ELEMENT; + face_tr.mesh = fespace.GetMesh(); + face_tr.Attribute = fespace.GetMesh()->GetBdrAttribute(e); + mfem::DenseMatrix &face_pm = face_tr.GetPointMat(); // dim x dof + face_tr.Reset(); + fespace.GetMesh()->GetNodes()->GetVectorValues(Loc1.Transf, face_ir, face_pm); + + // Get coordinates of element dofs + mfem::DenseMatrix elem_pm; + const mfem::FiniteElement *fe_elem = fespace.GetFE(elem_id); + const mfem::IntegrationRule &fe_elem_nodes = fe_elem->GetNodes(); + mfem::ElementTransformation *T = fespace.GetMesh()->GetElementTransformation(elem_id); + T->Transform(fe_elem_nodes, elem_pm); + + // Find the dofs + mfem::real_t tol = 1E-8; + mfem::Array elem_dofs; + fespace.GetElementDofs(elem_id, elem_dofs, dof_trans); // TODO: Check if passing dof_trans is OK + dofs.SetSize(P); + for (int l=0; l< P; l++) { + for (int m=0; m< elem_pm.Width(); m++) { + if (fabs(face_pm(0, l) - elem_pm(0, m)) < tol && fabs(face_pm(1, l) - elem_pm(1, m)) < tol && fabs(face_pm(2, l) - elem_pm(2, m)) < tol) { // TODO: make this robust + dofs[l] = elem_dofs[m]; + } + } + } } } else