diff --git a/src/theory/bv/theory_bv_type_rules.cpp b/src/theory/bv/theory_bv_type_rules.cpp index ed435448289..10df189c7f3 100644 --- a/src/theory/bv/theory_bv_type_rules.cpp +++ b/src/theory/bv/theory_bv_type_rules.cpp @@ -21,11 +21,48 @@ #include "util/bitvector.h" #include "util/cardinality.h" #include "util/integer.h" +#include "util/rational.h" namespace cvc5::internal { namespace theory { namespace bv { +bool isMaybeBoolean(const TypeNode& tn) +{ + return tn.isBoolean() || tn.isFullyAbstract(); +} + +/** + * Return true if tn is maybe a bit-vector type. Write to errOut if it exists + * and tn is not a maybe bit-vector type. + */ +bool checkMaybeBitVector(const TypeNode& tn, std::ostream* errOut) +{ + if (tn.isMaybeKind(Kind::BITVECTOR_TYPE)) + { + return true; + } + if (errOut) + { + (*errOut) << "expecting a bit-vector term"; + } + return false; +} + +/** + * Ensure that tn is a bit-vector type. + * Note this is equivalent to tn.leastUpperBound(?BitVec). + */ +TypeNode ensureBv(NodeManager* nm, const TypeNode& tn) +{ + if (tn.getKind() == Kind::ABSTRACT_TYPE + && tn.getAbstractedKind() == Kind::ABSTRACT_TYPE) + { + return nm->mkAbstractType(Kind::BITVECTOR_TYPE); + } + return tn; +} + Cardinality CardinalityComputer::computeCardinality(TypeNode type) { Assert(type.getKind() == Kind::BITVECTOR_TYPE); @@ -51,7 +88,11 @@ TypeNode BitVectorConstantTypeRule::computeType(NodeManager* nodeManager, { if (n.getConst().getSize() == 0) { - throw TypeCheckingExceptionPrivate(n, "constant of size 0"); + if (errOut) + { + (*errOut) << "constant of size 0"; + } + return TypeNode::null(); } } return nodeManager->mkBitVectorType(n.getConst().getSize()); @@ -66,25 +107,36 @@ TypeNode BitVectorFixedWidthTypeRule::computeType(NodeManager* nodeManager, bool check, std::ostream* errOut) { - TNode::iterator it = n.begin(); - TypeNode t = (*it).getType(check); - if (check) + TypeNode t; + for (const Node& nc : n) { - if (!t.isBitVector()) + TypeNode tc = nc.getTypeOrNull(); + if (check) + { + if (!checkMaybeBitVector(tc, errOut)) + { + return TypeNode::null(); + } + } + // if first child + if (t.isNull()) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); + t = tc; + continue; } - TNode::iterator it_end = n.end(); - for (++it; it != it_end; ++it) + t = t.leastUpperBound(tc); + if (t.isNull()) { - if ((*it).getType(check) != t) + if (errOut) { - throw TypeCheckingExceptionPrivate( - n, "expecting bit-vector terms of the same width"); + (*errOut) << "expecting comparable bit-vector terms"; } + return TypeNode::null(); } } - return t; + // ensure return is bitvector, e.g. if 2 fully abstract children, return + // ?BitVec. + return ensureBv(nodeManager, t); } TypeNode BitVectorPredicateTypeRule::preComputeType(NodeManager* nm, TNode n) @@ -98,18 +150,21 @@ TypeNode BitVectorPredicateTypeRule::computeType(NodeManager* nodeManager, { if (check) { - TypeNode lhsType = n[0].getType(check); - if (!lhsType.isBitVector()) + TypeNode lhsType = n[0].getTypeOrNull(); + if (!checkMaybeBitVector(lhsType, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); + return TypeNode::null(); } if (n.getNumChildren() > 1) { - TypeNode rhsType = n[1].getType(check); - if (lhsType != rhsType) + TypeNode rhsType = n[1].getTypeOrNull(); + if (!lhsType.isComparableTo(rhsType)) { - throw TypeCheckingExceptionPrivate( - n, "expecting bit-vector terms of the same width"); + if (errOut) + { + (*errOut) << "expecting comparable bit-vector terms"; + } + return TypeNode::null(); } } } @@ -127,10 +182,10 @@ TypeNode BitVectorRedTypeRule::computeType(NodeManager* nodeManager, { if (check) { - TypeNode type = n[0].getType(check); - if (!type.isBitVector()) + TypeNode type = n[0].getTypeOrNull(); + if (!checkMaybeBitVector(type, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } } return nodeManager->mkBitVectorType(1); @@ -147,17 +202,38 @@ TypeNode BitVectorBVPredTypeRule::computeType(NodeManager* nodeManager, { if (check) { - TypeNode lhs = n[0].getType(check); - TypeNode rhs = n[1].getType(check); - if (!lhs.isBitVector() || lhs != rhs) + TypeNode lhs = n[0].getTypeOrNull(); + TypeNode rhs = n[1].getTypeOrNull(); + if (!checkMaybeBitVector(lhs, errOut) || !checkMaybeBitVector(rhs, errOut) + || !lhs.isComparableTo(rhs)) { - throw TypeCheckingExceptionPrivate( - n, "expecting bit-vector terms of the same width"); + if (errOut) + { + (*errOut) << "expecting comparable bit-vector terms"; + } + return TypeNode::null(); } } return nodeManager->mkBitVectorType(1); } +TypeNode BitVectorSizeTypeRule::preComputeType(NodeManager* nm, TNode n) +{ + return nm->integerType(); +} +TypeNode BitVectorSizeTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check, + std::ostream* errOut) +{ + TypeNode t = n[0].getTypeOrNull(check); + if (!checkMaybeBitVector(t, errOut)) + { + return TypeNode::null(); + } + return nodeManager->integerType(); +} + TypeNode BitVectorConcatTypeRule::preComputeType(NodeManager* nm, TNode n) { return TypeNode::null(); @@ -168,18 +244,33 @@ TypeNode BitVectorConcatTypeRule::computeType(NodeManager* nodeManager, std::ostream* errOut) { uint32_t size = 0; + bool isAbstract = false; for (const auto& child : n) { - TypeNode t = child.getType(check); + TypeNode t = child.getTypeOrNull(); // NOTE: We're throwing a type-checking exception here even // when check is false, bc if we don't check that the arguments // are bit-vectors the result type will be inaccurate - if (!t.isBitVector()) + if (!checkMaybeBitVector(t, errOut)) + { + return TypeNode::null(); + } + if (isAbstract) + { + continue; + } + else if (t.isAbstract()) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector terms"); + isAbstract = true; + continue; } size += t.getBitVectorSize(); } + // if any child is abstract, we are abstract + if (isAbstract) + { + return nodeManager->mkAbstractType(Kind::BITVECTOR_TYPE); + } return nodeManager->mkBitVectorType(size); } @@ -195,10 +286,14 @@ TypeNode BitVectorToBVTypeRule::computeType(NodeManager* nodeManager, { for (const auto& child : n) { - TypeNode t = child.getType(check); - if (!t.isBoolean()) + TypeNode t = child.getTypeOrNull(); + if (!isMaybeBoolean(t)) { - throw TypeCheckingExceptionPrivate(n, "expecting Boolean terms"); + if (errOut) + { + (*errOut) << "expecting Boolean terms"; + } + return TypeNode::null(); } } return nodeManager->mkBitVectorType(n.getNumChildren()); @@ -214,23 +309,28 @@ TypeNode BitVectorITETypeRule::computeType(NodeManager* nodeManager, std::ostream* errOut) { Assert(n.getNumChildren() == 3); - TypeNode thenpart = n[1].getType(check); + TypeNode thenpart = n[1].getTypeOrNull(); + TypeNode elsepart = n[2].getTypeOrNull(); + // like ite, return is the join of the branches + TypeNode retType = thenpart.leastUpperBound(elsepart); if (check) { - TypeNode cond = n[0].getType(check); - if (cond != nodeManager->mkBitVectorType(1)) - { - throw TypeCheckingExceptionPrivate( - n, "expecting condition to be bit-vector term size 1"); - } - TypeNode elsepart = n[2].getType(check); - if (thenpart != elsepart) + TypeNode cond = n[0].getTypeOrNull(); + if (!nodeManager->mkBitVectorType(1).isComparableTo(cond)) { - throw TypeCheckingExceptionPrivate( - n, "expecting then and else parts to have same type"); + if (errOut) + { + (*errOut) << "expecting condition to be comparable with bit-vector " + "term size 1"; + } + return TypeNode::null(); } } - return thenpart; + if (retType.isNull() && errOut) + { + (*errOut) << "expecting then and else parts to have comparable types"; + } + return retType; } TypeNode BitVectorBitOfTypeRule::preComputeType(NodeManager* nm, TNode n) @@ -245,16 +345,19 @@ TypeNode BitVectorBitOfTypeRule::computeType(NodeManager* nodeManager, if (check) { BitVectorBitOf info = n.getOperator().getConst(); - TypeNode t = n[0].getType(check); - - if (!t.isBitVector()) + TypeNode t = n[0].getTypeOrNull(); + if (!checkMaybeBitVector(t, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } - if (info.d_bitIndex >= t.getBitVectorSize()) + // note this is not checked if the argument has abstract type + if (t.isBitVector() && info.d_bitIndex >= t.getBitVectorSize()) { - throw TypeCheckingExceptionPrivate( - n, "extract index is larger than the bitvector size"); + if (errOut) + { + (*errOut) << "extract index is larger than the bitvector size"; + } + return TypeNode::null(); } } return nodeManager->booleanType(); @@ -277,23 +380,33 @@ TypeNode BitVectorExtractTypeRule::computeType(NodeManager* nodeManager, // type will be illegal if (extractInfo.d_high < extractInfo.d_low) { - throw TypeCheckingExceptionPrivate( - n, "high extract index is smaller than the low extract index"); + if (errOut) + { + (*errOut) << "high extract index is smaller than the low extract index"; + } + return TypeNode::null(); } if (check) { - TypeNode t = n[0].getType(check); - if (!t.isBitVector()) + TypeNode t = n[0].getTypeOrNull(); + if (!checkMaybeBitVector(t, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } - if (extractInfo.d_high >= t.getBitVectorSize()) + // note this is not checked if the argument has abstract type + if (t.isBitVector() && extractInfo.d_high >= t.getBitVectorSize()) { - throw TypeCheckingExceptionPrivate( - n, "high extract index is bigger than the size of the bit-vector"); + if (errOut) + { + (*errOut) + << "high extract index is bigger than the size of the bit-vector"; + } + return TypeNode::null(); } } + // note that its type is always concrete, even if the argument has abstract + // type return nodeManager->mkBitVectorType(extractInfo.d_high - extractInfo.d_low + 1); } @@ -307,19 +420,30 @@ TypeNode BitVectorRepeatTypeRule::computeType(NodeManager* nodeManager, bool check, std::ostream* errOut) { - TypeNode t = n[0].getType(check); + TypeNode t = n[0].getTypeOrNull(); // NOTE: We're throwing a type-checking exception here even // when check is false, bc if the argument isn't a bit-vector // the result type will be inaccurate - if (!t.isBitVector()) + if (!checkMaybeBitVector(t, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } uint32_t repeatAmount = n.getOperator().getConst(); if (repeatAmount == 0) { - throw TypeCheckingExceptionPrivate(n, "expecting number of repeats > 0"); + if (errOut) + { + (*errOut) << "expecting number of repeats > 0"; + } + return TypeNode::null(); + } + // if abstract, we don't take into account the repeat amount, instead we + // return ?BitVec. + if (t.isAbstract()) + { + return ensureBv(nodeManager, t); } + Assert(t.isBitVector()); return nodeManager->mkBitVectorType(repeatAmount * t.getBitVectorSize()); } @@ -332,14 +456,19 @@ TypeNode BitVectorExtendTypeRule::computeType(NodeManager* nodeManager, bool check, std::ostream* errOut) { - TypeNode t = n[0].getType(check); + TypeNode t = n[0].getTypeOrNull(); // NOTE: We're throwing a type-checking exception here even // when check is false, bc if the argument isn't a bit-vector // the result type will be inaccurate - if (!t.isBitVector()) + if (!checkMaybeBitVector(t, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } + if (t.isAbstract()) + { + return ensureBv(nodeManager, t); + } + Assert(t.isBitVector()); uint32_t extendAmount = n.getKind() == Kind::BITVECTOR_SIGN_EXTEND ? n.getOperator().getConst() : n.getOperator().getConst(); @@ -357,10 +486,15 @@ TypeNode BitVectorEagerAtomTypeRule::computeType(NodeManager* nodeManager, { if (check) { - TypeNode lhsType = n[0].getType(check); + TypeNode lhsType = n[0].getTypeOrNull(); + // simple check to Boolean if (!lhsType.isBoolean()) { - throw TypeCheckingExceptionPrivate(n, "expecting boolean term"); + if (errOut) + { + (*errOut) << "expecting boolean term"; + } + return TypeNode::null(); } } return nodeManager->booleanType(); @@ -374,15 +508,15 @@ TypeNode BitVectorAckermanizationUdivTypeRule::preComputeType(NodeManager* nm, TypeNode BitVectorAckermanizationUdivTypeRule::computeType( NodeManager* nodeManager, TNode n, bool check, std::ostream* errOut) { - TypeNode lhsType = n[0].getType(check); + TypeNode lhsType = n[0].getTypeOrNull(); if (check) { - if (!lhsType.isBitVector()) + if (!checkMaybeBitVector(lhsType, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } } - return lhsType; + return ensureBv(nodeManager, lhsType); } TypeNode BitVectorAckermanizationUremTypeRule::preComputeType(NodeManager* nm, @@ -393,15 +527,15 @@ TypeNode BitVectorAckermanizationUremTypeRule::preComputeType(NodeManager* nm, TypeNode BitVectorAckermanizationUremTypeRule::computeType( NodeManager* nodeManager, TNode n, bool check, std::ostream* errOut) { - TypeNode lhsType = n[0].getType(check); + TypeNode lhsType = n[0].getTypeOrNull(); if (check) { - if (!lhsType.isBitVector()) + if (!checkMaybeBitVector(lhsType, errOut)) { - throw TypeCheckingExceptionPrivate(n, "expecting bit-vector term"); + return TypeNode::null(); } } - return lhsType; + return ensureBv(nodeManager, lhsType); } } // namespace bv diff --git a/src/theory/bv/theory_bv_type_rules.h b/src/theory/bv/theory_bv_type_rules.h index 2bccb3cc83a..bd21decbfbe 100644 --- a/src/theory/bv/theory_bv_type_rules.h +++ b/src/theory/bv/theory_bv_type_rules.h @@ -95,6 +95,16 @@ class BitVectorBVPredTypeRule /* non-parameterized operator kinds */ /* -------------------------------------------------------------------------- */ +class BitVectorSizeTypeRule +{ + public: + static TypeNode preComputeType(NodeManager* nm, TNode n); + static TypeNode computeType(NodeManager* nodeManager, + TNode n, + bool check, + std::ostream* errOut); +}; + class BitVectorConcatTypeRule { public: