diff --git a/Makefile.am b/Makefile.am index e18ad985c87..25b5aa0223a 100644 --- a/Makefile.am +++ b/Makefile.am @@ -340,6 +340,7 @@ libredex_la_SOURCES = \ service/type-analysis/TypeAnalysisRuntimeAssert.cpp \ service/type-analysis/ResolveMethodRefs.cpp \ service/type-string-rewriter/TypeStringRewriter.cpp \ + service/wrapped-primitives/WrappedPrimitives.cpp \ shared/DexDefs.cpp \ shared/DexEncoding.cpp \ shared/file-utils.cpp \ @@ -526,7 +527,7 @@ libopt_la_SOURCES = \ opt/virtual_merging/DedupVirtualMethods.cpp \ opt/virtual_merging/VirtualMerging.cpp \ opt/virtual_scope/MethodDevirtualizationPass.cpp \ - opt/wrapped-primitives/WrappedPrimitives.cpp + opt/wrapped-primitives/WrappedPrimitivesPass.cpp libopt_la_LIBADD = \ libredex.la \ diff --git a/Makefile.inc b/Makefile.inc index 11d9fa2e51d..6cdad4779a8 100644 --- a/Makefile.inc +++ b/Makefile.inc @@ -132,6 +132,7 @@ COMMON_INCLUDES = \ -I$(top_srcdir)/service/switch-partitioning \ -I$(top_srcdir)/service/type-analysis \ -I$(top_srcdir)/service/type-string-rewriter \ + -I$(top_srcdir)/service/wrapped-primitives \ -I$(top_srcdir)/shared \ -I$(top_srcdir)/sparta/include \ -I$(top_srcdir)/tools/common \ diff --git a/opt/constant-propagation/IPConstantPropagation.cpp b/opt/constant-propagation/IPConstantPropagation.cpp index 73cae3fbf11..f4111c964c1 100644 --- a/opt/constant-propagation/IPConstantPropagation.cpp +++ b/opt/constant-propagation/IPConstantPropagation.cpp @@ -21,6 +21,7 @@ #include "Timer.h" #include "Trace.h" #include "Walkers.h" +#include "WrappedPrimitives.h" namespace mog = method_override_graph; @@ -260,6 +261,12 @@ void PassImpl::optimize( code.cfg(), &xstores, method->get_class()); + // If configured, plug in IPCP state to do additional transforms + // (API unwrapping to primitives for known ObjectWithImmutAttr + // instances). + wrapped_primitives::optimize_method(type_system, ipa->fp_iter, + fp_iter.get_whole_program_state(), + method, code.cfg()); return tf.get_stats(); } }); diff --git a/opt/wrapped-primitives/WrappedPrimitives.cpp b/opt/wrapped-primitives/WrappedPrimitives.cpp deleted file mode 100644 index b01553cbe55..00000000000 --- a/opt/wrapped-primitives/WrappedPrimitives.cpp +++ /dev/null @@ -1,668 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "WrappedPrimitives.h" - -#include - -#include "CFGMutation.h" -#include "ConstantPropagationAnalysis.h" -#include "ConstantPropagationState.h" -#include "ConstantPropagationWholeProgramState.h" -#include "ConstructorParams.h" -#include "DexUtil.h" -#include "InitDeps.h" -#include "PassManager.h" -#include "Resolver.h" -#include "Show.h" -#include "Trace.h" -#include "Walkers.h" -#include "WorkQueue.h" - -namespace cp = constant_propagation; -namespace wp = wrapped_primitives; - -namespace { -// Check assumptions about the wrapper class's hierarchy. -void validate_wrapper_type(DexType* type) { - auto cls = type_class(type); - always_assert(cls != nullptr); - always_assert_log(cls->get_interfaces()->empty(), - "Wrapper type %s should not implement interfaces", - SHOW(type)); - auto super_cls = cls->get_super_class(); - always_assert_log(super_cls == type::java_lang_Object(), - "Wrapper type %s should inherit from Object; got %s", - SHOW(type), SHOW(super_cls)); -} - -void validate_api_mapping(DexMethodRef* from, DexMethodRef* to) { - // Simple validation for now; more involved use cases need to be added later. - always_assert_log( - from->get_class() == to->get_class(), - "Unable to map API from class %s to %s - they are expected to match", - SHOW(from->get_class()), SHOW(to->get_class())); -} - -// A wrapped primitive is assumed to be represented by the only final primitive -// field in the wrapper class. -DexType* get_wrapped_final_field_type(DexType* type) { - auto cls = type_class(type); - always_assert_log(cls != nullptr, "Spec class %s not found", SHOW(type)); - std::vector candidates; - for (auto& f : cls->get_ifields()) { - if (is_final(f) && type::is_primitive(f->get_type())) { - candidates.emplace_back(f); - } - } - always_assert_log(candidates.size() == 1, - "Expected 1 final field of primitive type in class %s", - SHOW(cls)); - return candidates.at(0)->get_type(); -} - -IROpcode sget_op_for_primitive(DexType* type) { - always_assert(type::is_primitive(type)); - if (type::is_boolean(type)) { - return OPCODE_SGET_BOOLEAN; - } else if (type::is_byte(type)) { - return OPCODE_SGET_BYTE; - } else if (type::is_char(type)) { - return OPCODE_SGET_CHAR; - } else if (type::is_short(type)) { - return OPCODE_SGET_SHORT; - } else if (type::is_int(type) || type::is_float(type)) { - return OPCODE_SGET; - } else { - return OPCODE_SGET_WIDE; - } -} - -IROpcode move_op_for_primitive(DexType* type) { - always_assert(type::is_primitive(type)); - if (type::is_wide_type(type)) { - return OPCODE_MOVE_WIDE; - } else { - return OPCODE_MOVE; - } -} - -IROpcode move_result_pseudo_op_for_primitive(DexType* type) { - always_assert(type::is_primitive(type)); - if (type::is_wide_type(type)) { - return IOPCODE_MOVE_RESULT_PSEUDO_WIDE; - } else { - return IOPCODE_MOVE_RESULT_PSEUDO; - } -} -} // namespace - -void WrappedPrimitivesPass::bind_config() { - std::vector wrappers; - bind("wrappers", {}, wrappers); - for (auto it = wrappers.begin(); it != wrappers.end(); ++it) { - const auto& value = *it; - always_assert_log(value.isObject(), - "Wrong specification: spec in array not an object."); - JsonWrapper json_obj = JsonWrapper(value); - wp::Spec spec; - std::string wrapper_desc; - json_obj.get("wrapper", "", wrapper_desc); - spec.wrapper = DexType::get_type(wrapper_desc); - always_assert_log(spec.wrapper != nullptr, "Type %s does not exist", - wrapper_desc.c_str()); - // Ensure the wrapper type matches expectations by the pass. - validate_wrapper_type(spec.wrapper); - spec.primitive = get_wrapped_final_field_type(spec.wrapper); - - // Unpack an array of objects, each object is just a 1 key/value to map an - // API using the wrapper type to the corresponding API of primitive type. - Json::Value allowed_invokes_array; - json_obj.get("allowed_invokes", Json::Value(), allowed_invokes_array); - always_assert_log( - allowed_invokes_array.isArray(), - "Wrong specification: allowed_invokes must be an array of objects."); - for (auto& obj : allowed_invokes_array) { - always_assert_log( - obj.isObject(), - "Wrong specification: allowed_invokes must be an array of objects."); - auto members = obj.getMemberNames(); - always_assert_log( - members.size() == 1, - "Wrong specification: allowed invoke object should be just 1 mapping " - "of method ref string to method ref string."); - auto api = members.at(0); - TRACE(WP, 2, "Checking for API '%s'", api.c_str()); - auto wrapped_api = DexMethod::get_method(api); - always_assert_log(wrapped_api != nullptr, "Method %s does not exist", - api.c_str()); - std::string unwrapped_api_desc; - JsonWrapper jobj = JsonWrapper(obj); - jobj.get(api.c_str(), "", unwrapped_api_desc); - always_assert_log(!unwrapped_api_desc.empty(), "empty!"); - TRACE(WP, 2, "Checking for unwrapped API '%s'", - unwrapped_api_desc.c_str()); - auto unwrapped_api = DexMethod::get_method(unwrapped_api_desc); - always_assert_log(unwrapped_api != nullptr, "Method %s does not exist", - unwrapped_api_desc.c_str()); - - // Make sure this API mapping is not obviously wrong up front. - validate_api_mapping(wrapped_api, unwrapped_api); - spec.allowed_invokes.emplace(wrapped_api, unwrapped_api); - TRACE(WP, 2, "Allowed API call %s -> %s", SHOW(wrapped_api), - SHOW(unwrapped_api)); - } - m_wrapper_specs.emplace_back(spec); - } - trait(Traits::Pass::unique, true); -} - -namespace { -bool has_static_final_wrapper_fields( - const std::unordered_map& wrapper_types, - DexClass* cls) { - for (auto& f : cls->get_sfields()) { - if (is_final(f) && wrapper_types.count(f->get_type()) > 0) { - return true; - } - } - return false; -} - -class ClinitMethodAnalysis : public wp::MethodAnalysis { - public: - ClinitMethodAnalysis( - const std::unordered_map& wrapper_types, - wp::PassState* pass_state, - DexClass* cls, - DexMethod* method) - : wp::MethodAnalysis(wrapper_types, pass_state, cls, method) {} - - void post_analyze() override { - // Construct the representation of all fields that were understood and set - // by the clinit. - std::unordered_map known_fields; - auto& cfg = get_cfg(); - auto intra_cp = get_fixpoint_iterator(); - auto exit_env = intra_cp->get_exit_state_at(cfg.exit_block()); - - m_pass_state->whole_program_state.collect_static_finals( - m_cls, exit_env.get_field_environment()); - for (auto f : m_cls->get_sfields()) { - if (m_wrapper_types.count(f->get_type()) > 0) { - TRACE(WP, 2, "Checking field %s", SHOW(f)); - auto field_value = exit_env.get(f); - auto maybe_constant = extract_object_attr_value(field_value); - if (maybe_constant != boost::none) { - auto constant = *maybe_constant; - TRACE(WP, - 2, - " ==> Field %s is a known object with constant value %" PRId64, - SHOW(f), - constant); - known_fields.emplace(f, constant); - } - } - } - - // Even for understood field values, avoid emitting nodes for fields that - // could be written to via different instructions/instances. Simplifies - // later validation logic. - std::unordered_set visited_fields; - for (auto* block : cfg.blocks()) { - for (auto& mie : InstructionIterable(block)) { - auto insn = mie.insn; - if (insn->opcode() == OPCODE_SPUT_OBJECT) { - auto field_def = insn->get_field()->as_def(); - if (field_def != nullptr) { - auto pair = visited_fields.emplace(field_def); - if (!pair.second) { - known_fields.erase(field_def); - TRACE(WP, 2, - " ==> Field %s written from multiple instructions; will " - "not consider", - SHOW(field_def)); - } - } - } - } - } - // Actual creation of nodes. - for (auto* block : cfg.blocks()) { - for (auto& mie : InstructionIterable(block)) { - auto insn = mie.insn; - if (insn->opcode() == OPCODE_SPUT_OBJECT) { - auto field_def = insn->get_field()->as_def(); - if (field_def != nullptr && known_fields.count(field_def) > 0) { - // Emit a representation of the instructions that created the object - // in this field. - auto defs = m_live_ranges->use_def_chains->at({insn, 0}); - TRACE(WP, 2, " %s -> %zu def(s)", SHOW(mie), defs.size()); - if (defs.size() == 1) { - auto def_insn = *defs.begin(); - // TODO: Is there any practical way to trigger this assert to fire - // for an understood value by collect_static_finals?? - always_assert_log(def_insn->opcode() == OPCODE_NEW_INSTANCE || - def_insn->opcode() == OPCODE_SGET_OBJECT, - "Unsupported instantiation %s", - SHOW(def_insn)); - if (def_insn->opcode() == OPCODE_NEW_INSTANCE) { - auto constant = known_fields.at(field_def); - emit_new_instance_node(constant, def_insn, field_def, insn); - } else { - emit_sget_node(def_insn, field_def, insn); - } - } - } - } - } - } - } -}; - -void analyze_clinit(const std::unordered_map& wrapper_types, - wp::PassState* pass_state, - DexClass* cls, - DexMethod* clinit) { - // Check if this method could be relevant before analyzing. - if (!has_static_final_wrapper_fields(wrapper_types, cls)) { - return; - } - - using CombinedClinitAnalyzer = InstructionAnalyzerCombiner< - cp::ClinitFieldAnalyzer, cp::WholeProgramAwareAnalyzer, - cp::ImmutableAttributeAnalyzer, cp::StaticFinalFieldAnalyzer, - cp::PrimitiveAnalyzer>; - - cp::WholeProgramStateAccessor wps_accessor(pass_state->whole_program_state); - - ClinitMethodAnalysis method_analysis(wrapper_types, pass_state, cls, clinit); - method_analysis.run(CombinedClinitAnalyzer(clinit->get_class(), - &wps_accessor, - &pass_state->attr_analyzer_state, - nullptr, - nullptr)); -} - -class FurtherMethodAnalysis : public wp::MethodAnalysis { - public: - FurtherMethodAnalysis( - const std::unordered_map& wrapper_types, - wp::PassState* pass_state, - DexClass* cls, - DexMethod* method) - : wp::MethodAnalysis(wrapper_types, pass_state, cls, method) {} - - void post_analyze() override { - // Continue building the representation of uses of all instances and fields, - // and their immediate uses. - auto& cfg = get_cfg(); - for (auto* block : cfg.blocks()) { - for (auto& mie : InstructionIterable(block)) { - auto insn = mie.insn; - if (insn->opcode() == OPCODE_SGET_OBJECT) { - auto field_def = - resolve_field(insn->get_field(), FieldSearch::Static); - if (field_def != nullptr && - m_pass_state->sfield_to_node.count(field_def) > 0) { - std::lock_guard lock(m_pass_state->modifications_mtx); - wp::Usage usage{insn, m_method}; - auto sget_node = std::make_unique(); - sget_node->item = usage; - // Find all users of the sget, add edges - attach_usage_nodes(sget_node); - // Then, connect the sget to pre-existing tree. - auto existing_node_ptr = m_pass_state->sfield_to_node.at(field_def); - existing_node_ptr->add_edge(std::move(sget_node)); - } - } - } - } - } -}; - -void analyze_method(const std::unordered_map& wrapper_types, - wp::PassState* pass_state, - DexClass* cls, - DexMethod* m) { - - using CombinedAnalyzer = InstructionAnalyzerCombiner< - cp::WholeProgramAwareAnalyzer, cp::ImmutableAttributeAnalyzer, - cp::StaticFinalFieldAnalyzer, cp::PrimitiveAnalyzer>; - - cp::WholeProgramStateAccessor wps_accessor(pass_state->whole_program_state); - FurtherMethodAnalysis method_analysis(wrapper_types, pass_state, cls, m); - method_analysis.run(CombinedAnalyzer( - &wps_accessor, &pass_state->attr_analyzer_state, nullptr, nullptr)); -} - -void transform_usage(const wp::Source& source, - const std::unique_ptr& ptr, - const wp::Spec& spec, - PassManager& mgr) { - auto usage = std::get(ptr->item); - auto& cfg = usage.method->get_code()->cfg(); - auto usage_it = cfg.find_insn(usage.insn); - cfg::CFGMutation mutation(cfg); - - auto get_insn_field = [&]() { - auto def = resolve_field(usage.insn->get_field(), FieldSearch::Static); - always_assert(def != nullptr && is_final(def) && is_static(def)); - return def; - }; - - auto op = usage.insn->opcode(); - if (op == OPCODE_SPUT_OBJECT) { - // Swap the field of wrapper type to the type of primitive in the original - // class. - auto def = get_insn_field(); - DexFieldSpec primitive_spec(def->get_class(), def->get_name(), - spec.primitive); - def->change(primitive_spec); - auto encoded_value = DexEncodedValue::zero_for_type(spec.primitive); - encoded_value->value(source.primitive_value); - def->set_value(std::move(encoded_value)); - TRACE(WP, 1, "Edited field spec: %s", SHOW(def)); - mgr.incr_metric("fields_changed", 1); - // Remove the sput-object; the encoded value will take its place. - mutation.remove(usage_it); - } else if (op == OPCODE_SGET_OBJECT) { - auto sget = new IRInstruction(sget_op_for_primitive(spec.primitive)); - auto def = get_insn_field(); - auto new_ref = - DexField::get_field(def->get_class(), def->get_name(), spec.primitive); - sget->set_field(new_ref); - // Update the following instruction too if it exists. - auto move_pseudo_it = cfg.move_result_of(usage_it); - if (move_pseudo_it.is_end()) { - mutation.replace(usage_it, {sget}); - } else { - auto move_pseudo = new IRInstruction( - move_result_pseudo_op_for_primitive(spec.primitive)); - move_pseudo->set_dest(move_pseudo_it->insn->dest()); - mutation.replace(usage_it, {sget, move_pseudo}); - } - mgr.incr_metric("sgets_changed", 1); - } else if (op == OPCODE_MOVE_OBJECT) { - auto move = new IRInstruction(move_op_for_primitive(spec.primitive)); - move->set_src(0, usage.insn->src(0)); - move->set_dest(usage.insn->dest()); - mutation.replace(usage_it, {move}); - } else { - always_assert_log(opcode::is_an_invoke(op), - "Unsupported instruction for patching: %s", - SHOW(usage.insn)); - // TODO: as capabilities of this pass expand, this logic may need to swap - // the opcode here too. For now, the types are asserted to match up front - // (which is simpler). - auto ref = usage.insn->get_method(); - auto search = spec.allowed_invokes.find(ref); - always_assert_log(search != spec.allowed_invokes.end(), - "Unconfigured invoke to %s was allowed as a valid usage", - SHOW(ref)); - auto unwrapped_ref = search->second; - usage.insn->set_method(unwrapped_ref); - mgr.incr_metric("invokes_changed", 1); - } - mutation.flush(); - // Continue making edits down the tree. - for (auto& next : ptr->edges) { - transform_usage(source, next, spec, mgr); - } -} - -void transform_node(const std::unique_ptr& source_ptr, - const wp::Spec& spec, - PassManager& mgr) { - auto source = std::get(source_ptr->item); - for (auto& ptr : source_ptr->edges) { - transform_usage(source, ptr, spec, mgr); - } -} - -// Checks the rstate of the method associated with node. Validation that allows/ -// disallows transforms should respect this. -bool no_optimizations(const wp::Spec& spec, - const std::unique_ptr& ptr) { - auto method = ptr->get_method(); - if (method->rstate.no_optimizations()) { - TRACE(WP, - 2, - "[%s] Unsupported method %s via rstate", - SHOW(spec.wrapper), - SHOW(method)); - return true; - } - return false; -} - -bool validate_usage(const std::unique_ptr& ptr, - const wp::Spec& spec, - PassManager& mgr) { - always_assert(ptr->is_usage()); - if (no_optimizations(spec, ptr)) { - return false; - } - auto usage = std::get(ptr->item); - auto log_unsupported = [&]() { - TRACE(WP, - 2, - "[%s] Unsupported usage %s from method %s", - SHOW(spec.wrapper), - SHOW(usage.insn), - SHOW(usage.method)); - mgr.incr_metric("unsupported_usage", 1); - }; - auto op = usage.insn->opcode(); - if (op == OPCODE_SPUT_OBJECT || op == OPCODE_SGET_OBJECT) { - auto def = resolve_field(usage.insn->get_field(), FieldSearch::Static); - if (def == nullptr || def->get_type() != spec.wrapper || !is_final(def) || - !def->rstate.can_delete()) { - log_unsupported(); - return false; - } - } else if (opcode::is_an_invoke(op)) { - // Check for invocations to configured method(s) - if (spec.allowed_invokes.count(usage.insn->get_method()) == 0) { - log_unsupported(); - return false; - } - } else if (op == OPCODE_MOVE_OBJECT) { - // Support this automatically; patching this to change to primitve should - // be fine. No logic here intentionally. - } else { - log_unsupported(); - return false; - } - for (auto& next : ptr->edges) { - if (!validate_usage(next, spec, mgr)) { - return false; - } - } - return true; -} - -// Returns true if the given node and all its downstream usages are simple -// enough to be transformed by this pass. Increments metrics for unsupported -// usages. -bool validate_node(const std::unique_ptr& source_ptr, - const wp::Spec& spec, - PassManager& mgr) { - always_assert(source_ptr->is_source()); - if (no_optimizations(spec, source_ptr)) { - return false; - } - for (auto& ptr : source_ptr->edges) { - if (!validate_usage(ptr, spec, mgr)) { - return false; - } - } - return true; -} - -void print_edge(const size_t indent, const std::unique_ptr& ptr) { - always_assert(ptr->is_usage()); - std::string indent_str(indent, ' '); - auto& usage = std::get(ptr->item); - TRACE(WP, - 1, - "%s-> USAGE@%p { %s (%s) }", - indent_str.c_str(), - ptr.get(), - SHOW(usage.method), - SHOW(usage.insn)); - for (auto& next : ptr->edges) { - print_edge(indent + 2, next); - } -} - -void print_node(std::unique_ptr& node, bool edges = true) { - always_assert(node->is_source()); - auto& source = std::get(node->item); - TRACE(WP, - 1, - "NODE@%p { %s (%s %s) value = %" PRId64 " }", - node.get(), - SHOW(source.method), - SHOW(source.new_instance), - SHOW(source.init), - source.primitive_value); - if (edges) { - for (auto& ptr : node->edges) { - print_edge(2, ptr); - } - } -} -} // namespace - -void WrappedPrimitivesPass::eval_pass(DexStoresVector& stores, - ConfigFiles& conf, - PassManager&) { - for (auto& spec : m_wrapper_specs) { - for (auto&& [from, to] : spec.allowed_invokes) { - auto def = to->as_def(); - if (def != nullptr && def->rstate.can_delete()) { - TRACE(WP, 2, "Setting %s as root", SHOW(def)); - def->rstate.set_root(); - m_marked_root_methods.emplace(def); - auto cls = type_class(def->get_class()); - if (cls->rstate.can_delete()) { - TRACE(WP, 2, "Setting %s as root", SHOW(cls)); - cls->rstate.set_root(); - m_marked_root_classes.emplace(cls); - } - } - } - for (auto& method : spec.wrapper_type_constructors()) { - if (!method->rstate.dont_inline()) { - method->rstate.set_dont_inline(); - TRACE(WP, 2, "Disallowing inlining for %s", SHOW(method)); - } - } - } -} - -// Undoes the changes made by eval_pass -void WrappedPrimitivesPass::unset_roots() { - for (auto& def : m_marked_root_methods) { - TRACE(WP, 2, "Unsetting %s as root", SHOW(def)); - def->rstate.unset_root(); - } - for (auto& cls : m_marked_root_classes) { - TRACE(WP, 2, "Unsetting %s as root", SHOW(cls)); - cls->rstate.unset_root(); - } -} - -void WrappedPrimitivesPass::run_pass(DexStoresVector& stores, - ConfigFiles& /* unused */, - PassManager& mgr) { - std::unordered_map wrapper_types; - wp::PassState pass_state; - for (auto& spec : m_wrapper_specs) { - TRACE(WP, - 1, - "Will check for wrapper type %s with supported methods:", - SHOW(spec.wrapper)); - for (auto&& [from, to] : spec.allowed_invokes) { - TRACE(WP, 1, " %s", SHOW(from)); - } - auto wrapper_cls = type_class(spec.wrapper); - always_assert(wrapper_cls != nullptr); - wrapper_types.emplace(spec.wrapper, spec); - cp::immutable_state::analyze_constructors({wrapper_cls}, - &pass_state.attr_analyzer_state); - } - - // First phase: analyze clinit methods to find static final field values. - // Begin assembling a tree of construction of the wrapper types, their - // immediate usages, and their writes and reads to static final fields. - auto scope = build_class_scope(stores); - size_t possible_cycles{0}; - auto sorted_scope = - init_deps::reverse_tsort_by_clinit_deps(scope, possible_cycles); - for (auto cls : sorted_scope) { - if (cls->is_external()) { - continue; - } - auto clinit = cls->get_clinit(); - if (clinit != nullptr && clinit->get_code() != nullptr) { - analyze_clinit(wrapper_types, &pass_state, cls, clinit); - } - } - - // Continue analyzing the scope, find all uses of static final fields from the - // initial phase. Continue building the tree of usages. - InsertOnlyConcurrentSet further_analysis_set; - walk::parallel::opcodes(scope, [&](DexMethod* m, IRInstruction* insn) { - if (insn->opcode() == OPCODE_SGET_OBJECT) { - auto ref = insn->get_field(); - auto def = resolve_field(ref, FieldSearch::Static); - if (def != nullptr && is_final(def) && is_static(def) && - wrapper_types.count(def->get_type()) > 0) { - further_analysis_set.insert(m); - } - } - }); - workqueue_run( - [&](DexMethod* m) { - analyze_method(wrapper_types, &pass_state, type_class(m->get_class()), - m); - }, - further_analysis_set, - traceEnabled(WP, 9) ? 1 : redex_parallel::default_num_threads()); - - TRACE(WP, 1, "\nDumping nodes:"); - for (auto& node : pass_state.forest.nodes) { - print_node(node); - TRACE(WP, 1, ""); - } - TRACE(WP, 1, "*************************************************************"); - - // For each understood creation of a wrapper type, check if all usages fit - // into a very narrow definition of supported uses that could easily be - // swapped out for its wrapped primitive type. - for (auto& ptr : pass_state.forest.nodes) { - auto source = std::get(ptr->item); - auto spec = wrapper_types.at(source.new_instance->get_type()); - if (validate_node(ptr, spec, mgr)) { - TRACE(WP, 1, "SUPPORTED:"); - print_node(ptr); - transform_node(ptr, spec, mgr); - } else { - TRACE(WP, 1, "Not supported:"); - print_node(ptr, false); - } - TRACE(WP, 1, ""); - } - - // Lastly, undo any reachability modifications that were applied during - // eval_pass. - unset_roots(); -} - -static WrappedPrimitivesPass s_pass; diff --git a/opt/wrapped-primitives/WrappedPrimitives.h b/opt/wrapped-primitives/WrappedPrimitives.h deleted file mode 100644 index 3d96e596df0..00000000000 --- a/opt/wrapped-primitives/WrappedPrimitives.h +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -#include "ConstantPropagationAnalysis.h" -#include "ConstantPropagationWholeProgramState.h" -#include "DexClass.h" -#include "IRInstruction.h" -#include "Lazy.h" -#include "LiveRange.h" -#include "Pass.h" -#include "Trace.h" - -namespace wrapped_primitives { -// A config driven spec describing wrapper classes to look for, each of which is -// asserted to have 1 final field of some primitive type. Beyond this, -// assumptions also include: -// 1) A constructor taking 1 argument which is the primitive it wraps. -// 2) Wrapper class extends java.lang.Object and does not implement interfaces. -// -// Wrapper class instances that can effectively be "unboxed" by this pass must -// conform to a very narrow set of usages. Currently, supported uses are: -// - Wrapper class can be instantiated with a known constant (known means -// intraprocedural constant propagation can easily figure it out). -// - Wrapper class instances can be written to static final fields. -// - Wrapper class instances can be retrieved from static final fields. -// - Wrapper class instances can be an argument to a set of configured "allowed -// - invokes" i.e. method refs that they can be passed to. -// -// Finally, the input program must honor guarantees about the allowed method -// invocations. For the output program to type check properly, it must be -// explicitly listed for every allowed API taking the wrapper class, what is the -// corresponding primitive API that should be swapped in. It is up to the author -// of the input program to ensure that this works in practice, otherwise Redex -// is free to fail in whatever way it chooses (i.e. fail the build or optimize -// no wrapper types). -// -// EXAMPLE: -// "LFoo;.a:(LMyLong;)V" is an allowed invoke, the config should map this to -// something like "LFoo;.a:(J)V" which will also need to exist in the input -// program. This is the simplest form. If however, the allowed invoke maps to an -// API on a different type, say from an interface method to a method on the -// interface's underlying implenentor, check-cast instructions may need to be -// inserted to make this work. It's up to the program's authors to ensure this -// ends up as a working app (and we may fail the build otherwise, or insert -// casts that would fail at runtime if things are misconfigured). -struct Spec { - DexType* wrapper{nullptr}; - DexType* primitive{nullptr}; - std::map allowed_invokes; - - std::vector wrapper_type_constructors() { - auto cls = type_class(wrapper); - return cls->get_ctors(); - } -}; - -// Details pertaining to an understood instantiation of a wrapper class with a -// known primitive given to its constructor. -struct Source { - IRInstruction* new_instance; - IRInstruction* init; - DexMethod* method; - int64_t primitive_value; -}; - -// A point in the code which a wrapper class is being used (beyond it being -// instantiation). -struct Usage { - IRInstruction* insn; - DexMethod* method; -}; - -// Represents a tree of instantiation (Source) to many Usages (which can have) -// their own uses. Equality checks are here to let this be built up in rounds. -struct Node { - std::variant item; - std::vector> edges; - std::unordered_set seen; - - bool is_source() { return std::holds_alternative(item); } - bool is_usage() { return std::holds_alternative(item); } - - // For use-def analysis, the instruction that could be followed up by uses - IRInstruction* get_def_instruction() { - if (is_source()) { - return std::get(item).new_instance; - } else { - return std::get(item).insn; - } - } - - DexMethod* get_method() { - if (is_source()) { - return std::get(item).method; - } else { - return std::get(item).method; - } - } - - void add_edge(std::unique_ptr node) { - always_assert(node->is_usage()); - auto insn = std::get(node->item).insn; - if (seen.count(insn) == 0) { - seen.emplace(insn); - edges.emplace_back(std::move(node)); - } - } -}; - -// Allow for Nodes to be built up sequentially in rounds, keeping track of only -// newly seen things. -struct Forest { - std::vector> nodes; - std::unordered_set seen; - - void add_node(std::unique_ptr node) { - always_assert(node->is_source()); - auto insn = std::get(node->item).new_instance; - if (seen.count(insn) == 0) { - seen.emplace(insn); - nodes.emplace_back(std::move(node)); - } - } -}; - -// Global state of the pass as it analyzes static fields and their usages. -struct PassState { - Forest forest; - std::unordered_map sfield_to_node; - constant_propagation::WholeProgramState whole_program_state; - constant_propagation::ImmutableAttributeAnalyzerState attr_analyzer_state; - // For modifications to the tree of source/usages. - std::mutex modifications_mtx; -}; - -class MethodAnalysis { - public: - MethodAnalysis(const std::unordered_map& wrapper_types, - PassState* pass_state, - DexClass* cls, - DexMethod* method) - : m_wrapper_types(wrapper_types), - m_pass_state(pass_state), - m_cls(cls), - m_method(method), - m_live_ranges([&]() { - return std::make_unique(get_cfg()); - }) { - auto& cfg = get_cfg(); - cfg.calculate_exit_block(); - } - - virtual ~MethodAnalysis() {} - - cfg::ControlFlowGraph& get_cfg() { return m_method->get_code()->cfg(); } - - // Checks if the value is a known ObjectWithImmutAttr with a single known - // attribute value. Makes assumptions that there is only 1, as is consistent - // with the other assumptions in the pass. - boost::optional extract_object_attr_value( - const ConstantValue& value) { - auto obj_or_none = value.maybe_get(); - if (obj_or_none != boost::none && - obj_or_none->get_constant() != boost::none) { - auto object = *obj_or_none->get_constant(); - always_assert(object.attributes.size() == 1); - auto signed_value = - object.attributes.front().value.maybe_get(); - if (signed_value != boost::none && - signed_value.value().get_constant() != boost::none) { - return *signed_value.value().get_constant(); - } else { - TRACE(WP, 2, " No SignedConstantDomain value"); - } - } else { - TRACE(WP, 2, " Not a known ObjectWithImmutAttrDomain"); - } - - return boost::none; - } - - // For a def instruction (asserted to be a new-instance), find the usage that - // invokes the constructor. Asserts there is only 1. - IRInstruction* find_invoke_ctor(IRInstruction* new_instance) { - IRInstruction* invoke_ctor{nullptr}; - auto& uses = m_live_ranges->def_use_chains->at(new_instance); - for (auto u : uses) { - if (u.insn->opcode() == OPCODE_INVOKE_DIRECT && - method::is_init(u.insn->get_method())) { - if (u.insn->get_method()->get_class() == new_instance->get_type()) { - always_assert(invoke_ctor == nullptr); - invoke_ctor = u.insn; - } - } - } - return invoke_ctor; - } - - // For information about the instantiation or get of a wrapped type, attach - // the node to the pass state's representation, along with nodes for all - // immediate uses of the def. - void attach_usage_nodes( - std::unique_ptr& def_node, - const std::unordered_set& exceptions) { - auto def = def_node->get_def_instruction(); - auto& uses = m_live_ranges->def_use_chains->at(def); - TRACE(WP, 2, "%s has %zu use(s)", SHOW(def), uses.size()); - // Make nodes for the use(s) - for (auto u : uses) { - if (exceptions.count(u.insn) > 0) { - continue; - } - Usage usage{u.insn, m_method}; - auto usage_node = std::make_unique(); - usage_node->item = usage; - def_node->add_edge(std::move(usage_node)); - } - } - - void attach_usage_nodes(std::unique_ptr& def_node) { - attach_usage_nodes(def_node, {}); - } - - // Keeps track of global state for the node of a field, so that further usages - // can be connected to the pass state's representation. - void store_sput_node_pointer(std::unique_ptr& def_node, - DexField* put_field_def, - IRInstruction* sput) { - for (auto& usage_node : def_node->edges) { - auto usage = std::get(usage_node->item); - if (usage.insn == sput) { - auto pair = m_pass_state->sfield_to_node.emplace(put_field_def, - usage_node.get()); - if (pair.second) { - TRACE(WP, - 2, - " field %s will map to usage %p", - SHOW(put_field_def), - usage_node.get()); - } else { - auto ptr = pair.first->second; - TRACE(WP, - 2, - " field %s has redundant put; previous usage node %p will " - "take effect", - SHOW(put_field_def), - ptr); - } - } - } - } - - // For a def that was instantiated by the method, emit a node and attach to - // the pass state's representation. - void emit_new_instance_node(const int64_t constant, - IRInstruction* new_instance, - DexField* put_field_def, - IRInstruction* sput) { - auto invoke_ctor = find_invoke_ctor(new_instance); - Source source{new_instance, invoke_ctor, m_method, constant}; - auto node = std::make_unique(); - node->item = source; - // Find all users of the new-instance, add edges. - attach_usage_nodes(node, {invoke_ctor}); - // Track sput-object specially, as explained above. - store_sput_node_pointer(node, put_field_def, sput); - // Connect this to the forest. - m_pass_state->forest.add_node(std::move(node)); - } - - // For a def that was from an sget, emit a node and attach to the pass state's - // representation. - void emit_sget_node(IRInstruction* sget, - DexField* put_field_def, - IRInstruction* sput) { - auto resolved_get_field_def = - resolve_field(sget->get_field(), FieldSearch::Static); - always_assert_log(resolved_get_field_def != nullptr, - "Unable to resolve field from instruction %s", - SHOW(sget)); - - Usage sget_usage{sget, m_method}; - auto node = std::make_unique(); - node->item = sget_usage; - - // Find all users of the sget. - attach_usage_nodes(node); - // Track sput-object specially, as explained above. - store_sput_node_pointer(node, put_field_def, sput); - // Connect this to the appropriate parent. - auto& parent = m_pass_state->sfield_to_node.at(resolved_get_field_def); - parent->add_edge(std::move(node)); - } - - // Follow-up work after running the fixpoint iterator. Implementation specific - virtual void post_analyze() {} - - void run(const InstructionAnalyzer& insn_analyzer) { - auto& cfg = get_cfg(); - TRACE(WP, 3, "Analyzing %s %s", SHOW(m_method), SHOW(cfg)); - m_fp_iter = std::make_unique< - constant_propagation::intraprocedural::FixpointIterator>( - /* cp_state */ nullptr, cfg, insn_analyzer); - m_fp_iter->run(ConstantEnvironment()); - post_analyze(); - } - - constant_propagation::intraprocedural::FixpointIterator* - get_fixpoint_iterator() { - return m_fp_iter.get(); - } - - protected: - const std::unordered_map& m_wrapper_types; - PassState* m_pass_state; - DexClass* m_cls; - DexMethod* m_method; - Lazy m_live_ranges; - - std::unique_ptr - m_fp_iter; -}; -} // namespace wrapped_primitives - -// A wrapped primitive is a type with a constructor taking a primitive, that is -// largely used to achieve some special kind of type safety above just a -// primitive. Configurations will specify the wrapper type name, and APIs that -// it is sanctioned to be used in. For wrapper instances that can be replaced -// directly with the primitive itself safely (based on easily understood -// instantiation and no unsupported usages) this pass will make modifications. -class WrappedPrimitivesPass : public Pass { - public: - WrappedPrimitivesPass() : Pass("WrappedPrimitivesPass") {} - - redex_properties::PropertyInteractions get_property_interactions() - const override { - using namespace redex_properties::interactions; - using namespace redex_properties::names; - return { - {DexLimitsObeyed, Preserves}, - {NoResolvablePureRefs, Preserves}, - {UltralightCodePatterns, Preserves}, - {InitialRenameClass, Preserves}, - }; - } - - void bind_config() override; - void eval_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; - void run_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; - void unset_roots(); - - private: - std::vector m_wrapper_specs; - // Config driven optimization will create inbound references to new methods. - // These methods need to not be deleted. - std::unordered_set m_marked_root_classes; - std::unordered_set m_marked_root_methods; -}; diff --git a/opt/wrapped-primitives/WrappedPrimitivesPass.cpp b/opt/wrapped-primitives/WrappedPrimitivesPass.cpp new file mode 100644 index 00000000000..081ad32f832 --- /dev/null +++ b/opt/wrapped-primitives/WrappedPrimitivesPass.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "WrappedPrimitivesPass.h" + +#include + +#include "DexUtil.h" +#include "PassManager.h" +#include "Show.h" +#include "Trace.h" +#include "Walkers.h" +#include "WrappedPrimitives.h" + +namespace cp = constant_propagation; +namespace mog = method_override_graph; +namespace wp = wrapped_primitives; + +namespace { + +constexpr const char* METRIC_CONSTS_INSERTED = "const_instructions_inserted"; +constexpr const char* METRIC_CASTS_INSERTED = "check_casts_inserted"; + +// Check assumptions about the wrapper class's hierarchy. +void validate_wrapper_type(DexType* type) { + auto cls = type_class(type); + always_assert(cls != nullptr); + always_assert_log(cls->get_interfaces()->empty(), + "Wrapper type %s should not implement interfaces", + SHOW(type)); + auto super_cls = cls->get_super_class(); + always_assert_log(super_cls == type::java_lang_Object(), + "Wrapper type %s should inherit from Object; got %s", + SHOW(type), SHOW(super_cls)); +} + +// A wrapped primitive is assumed to be represented by the only final primitive +// field in the wrapper class. +DexType* get_wrapped_final_field_type(DexType* type) { + auto cls = type_class(type); + always_assert_log(cls != nullptr, "Spec class %s not found", SHOW(type)); + std::vector candidates; + for (auto& f : cls->get_ifields()) { + if (is_final(f) && type::is_primitive(f->get_type())) { + candidates.emplace_back(f); + } + } + always_assert_log(candidates.size() == 1, + "Expected 1 final field of primitive type in class %s", + SHOW(cls)); + return candidates.at(0)->get_type(); +} +} // namespace + +void WrappedPrimitivesPass::bind_config() { + std::vector wrappers; + std::vector wrapper_specs; + bind("wrappers", {}, wrappers); + for (auto it = wrappers.begin(); it != wrappers.end(); ++it) { + const auto& value = *it; + always_assert_log(value.isObject(), + "Wrong specification: spec in array not an object."); + JsonWrapper json_obj = JsonWrapper(value); + wp::Spec spec; + std::string wrapper_desc; + json_obj.get("wrapper", "", wrapper_desc); + spec.wrapper = DexType::get_type(wrapper_desc); + always_assert_log(spec.wrapper != nullptr, "Type %s does not exist", + wrapper_desc.c_str()); + // Ensure the wrapper type matches expectations by the pass. + validate_wrapper_type(spec.wrapper); + spec.primitive = get_wrapped_final_field_type(spec.wrapper); + + // Unpack an array of objects, each object is just a 1 key/value to map an + // API using the wrapper type to the corresponding API of primitive type. + Json::Value allowed_invokes_array; + json_obj.get("allowed_invokes", Json::Value(), allowed_invokes_array); + always_assert_log( + allowed_invokes_array.isArray(), + "Wrong specification: allowed_invokes must be an array of objects."); + for (auto& obj : allowed_invokes_array) { + always_assert_log( + obj.isObject(), + "Wrong specification: allowed_invokes must be an array of objects."); + auto members = obj.getMemberNames(); + always_assert_log( + members.size() == 1, + "Wrong specification: allowed invoke object should be just 1 mapping " + "of method ref string to method ref string."); + auto api = members.at(0); + TRACE(WP, 2, "Checking for API '%s'", api.c_str()); + auto wrapped_api = DexMethod::get_method(api); + always_assert_log(wrapped_api != nullptr, "Method %s does not exist", + api.c_str()); + std::string unwrapped_api_desc; + JsonWrapper jobj = JsonWrapper(obj); + jobj.get(api.c_str(), "", unwrapped_api_desc); + always_assert_log(!unwrapped_api_desc.empty(), "empty!"); + TRACE(WP, 2, "Checking for unwrapped API '%s'", + unwrapped_api_desc.c_str()); + auto unwrapped_api = DexMethod::get_method(unwrapped_api_desc); + always_assert_log(unwrapped_api != nullptr, "Method %s does not exist", + unwrapped_api_desc.c_str()); + spec.allowed_invokes.emplace(wrapped_api, unwrapped_api); + TRACE(WP, 2, "Allowed API call %s -> %s", SHOW(wrapped_api), + SHOW(unwrapped_api)); + } + wrapper_specs.emplace_back(spec); + } + wp::initialize(wrapper_specs); + trait(Traits::Pass::unique, true); +} + +void WrappedPrimitivesPass::eval_pass(DexStoresVector& stores, + ConfigFiles& conf, + PassManager&) { + wp::get_instance()->mark_roots(); +} + +void WrappedPrimitivesPass::run_pass(DexStoresVector& stores, + ConfigFiles& /* unused */, + PassManager& mgr) { + auto wp_instance = wp::get_instance(); + wp_instance->unmark_roots(); + + auto consts = wp_instance->consts_inserted(); + TRACE(WP, 1, "const instructions inserted: %zu", consts); + mgr.set_metric(METRIC_CONSTS_INSERTED, consts); + + auto casts = wp_instance->casts_inserted(); + TRACE(WP, 1, "check-cast instructions inserted: %zu", casts); + mgr.set_metric(METRIC_CASTS_INSERTED, casts); + + // Clear state so that no futher work gets done from multiple rounds of IPCP + wp::initialize({}); +} + +static WrappedPrimitivesPass s_pass; diff --git a/opt/wrapped-primitives/WrappedPrimitivesPass.h b/opt/wrapped-primitives/WrappedPrimitivesPass.h new file mode 100644 index 00000000000..2e5b46ad122 --- /dev/null +++ b/opt/wrapped-primitives/WrappedPrimitivesPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "Pass.h" + +// A wrapped primitive is a type with a constructor taking a primitive, that is +// largely used to achieve some special kind of type safety above just a +// primitive. Configurations will specify the wrapper type name, and APIs that +// it is sanctioned to be used in. For wrapper instances that can be replaced +// directly with the primitive itself safely (based on easily understood +// instantiation) this pass will make modifications. +class WrappedPrimitivesPass : public Pass { + public: + WrappedPrimitivesPass() : Pass("WrappedPrimitivesPass") {} + + redex_properties::PropertyInteractions get_property_interactions() + const override { + using namespace redex_properties::interactions; + using namespace redex_properties::names; + return { + {DexLimitsObeyed, Preserves}, + {NoResolvablePureRefs, Preserves}, + {UltralightCodePatterns, Preserves}, + {InitialRenameClass, Preserves}, + }; + } + + void bind_config() override; + void eval_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; + void run_pass(DexStoresVector&, ConfigFiles&, PassManager&) override; +}; diff --git a/service/wrapped-primitives/WrappedPrimitives.cpp b/service/wrapped-primitives/WrappedPrimitives.cpp new file mode 100644 index 00000000000..1a473d2051b --- /dev/null +++ b/service/wrapped-primitives/WrappedPrimitives.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "WrappedPrimitives.h" + +#include +#include + +#include "CFGMutation.h" +#include "ConstantEnvironment.h" +#include "ConstantPropagationAnalysis.h" +#include "ConstantPropagationState.h" +#include "ConstantPropagationWholeProgramState.h" +#include "ConstructorParams.h" +#include "DexUtil.h" +#include "IPConstantPropagationAnalysis.h" +#include "InitDeps.h" +#include "LiveRange.h" +#include "MethodOverrideGraph.h" +#include "PassManager.h" +#include "RedexContext.h" +#include "Show.h" +#include "Trace.h" +#include "TypeSystem.h" +#include "Walkers.h" +#include "WorkQueue.h" + +namespace wrapped_primitives { +static std::unique_ptr s_instance{nullptr}; +WrappedPrimitives* get_instance() { return s_instance.get(); } + +void initialize(const std::vector& wrapper_specs) { + s_instance.reset(new WrappedPrimitives(wrapper_specs)); + // In tests, we create and destroy g_redex repeatedly. So we need to reset + // the singleton. + g_redex->add_destruction_task([]() { s_instance.reset(nullptr); }); +} + +void WrappedPrimitives::mark_roots() { + for (auto& spec : m_wrapper_specs) { + for (auto&& [from, to] : spec.allowed_invokes) { + auto def = to->as_def(); + if (def != nullptr && def->rstate.can_delete()) { + TRACE(WP, 2, "Setting %s as root", SHOW(def)); + def->rstate.set_root(); + m_marked_root_methods.emplace(def); + auto cls = type_class(def->get_class()); + if (cls->rstate.can_delete()) { + TRACE(WP, 2, "Setting %s as root", SHOW(cls)); + cls->rstate.set_root(); + m_marked_root_classes.emplace(cls); + } + } + } + for (auto& method : spec.wrapper_type_constructors()) { + if (!method->rstate.dont_inline()) { + method->rstate.set_dont_inline(); + TRACE(WP, 2, "Disallowing inlining for %s", SHOW(method)); + } + } + } +} + +void WrappedPrimitives::unmark_roots() { + for (auto& def : m_marked_root_methods) { + TRACE(WP, 2, "Unsetting %s as root", SHOW(def)); + def->rstate.unset_root(); + } + for (auto& cls : m_marked_root_classes) { + TRACE(WP, 2, "Unsetting %s as root", SHOW(cls)); + cls->rstate.unset_root(); + } +} + +namespace { +bool contains_relevant_invoke( + const std::unordered_set& wrapped_apis, + DexMethod* method) { + if (wrapped_apis.empty()) { + return false; + } + auto& cfg = method->get_code()->cfg(); + auto iterable = cfg::InstructionIterable(cfg); + for (auto it = iterable.begin(); it != iterable.end(); ++it) { + IRInstruction* insn = it->insn; + if (insn->has_method() && wrapped_apis.count(insn->get_method()) > 0) { + return true; + } + } + return false; +} + +// Checks if the value is a known ObjectWithImmutAttr with a single known +// attribute value. Makes assumptions that there is only 1, as is consistent +// with the other assumptions in the pass. +boost::optional> +extract_object_with_attr_value(const ConstantValue& value) { + auto obj_or_none = value.maybe_get(); + if (obj_or_none != boost::none && + obj_or_none->get_constant() != boost::none) { + auto object = *obj_or_none->get_constant(); + always_assert(object.attributes.size() == 1); + auto signed_value = + object.attributes.front().value.maybe_get(); + if (signed_value != boost::none && + signed_value.value().get_constant() != boost::none) { + auto primitive_value = *signed_value.value().get_constant(); + return std::pair(object.type, primitive_value); + } else { + TRACE(WP, 2, " No SignedConstantDomain value"); + } + } else { + TRACE(WP, 2, " Not a known ObjectWithImmutAttrDomain"); + } + + return boost::none; +} + +bool needs_cast(const TypeSystem& type_system, + DexMethodRef* from_ref, + DexMethodRef* to_ref) { + auto from = from_ref->get_class(); + auto to = to_ref->get_class(); + if (from == to) { + return false; + } + if (is_interface(type_class(from))) { + auto supers = type_system.get_all_super_interfaces(from); + return supers.count(to) == 0; + } else { + if (is_interface(type_class(to))) { + return !type_system.implements(from, to); + } else { + return !type_system.is_subtype(to, from); + } + } +} +} // namespace + +void WrappedPrimitives::increment_consts() { m_consts_inserted++; } + +void WrappedPrimitives::increment_casts() { m_casts_inserted++; } + +void WrappedPrimitives::optimize_method( + const TypeSystem& type_system, + const cp::intraprocedural::FixpointIterator& intra_cp, + const cp::WholeProgramState& wps, + DexMethod* method, + cfg::ControlFlowGraph& cfg) { + if (method->get_code() == nullptr || method->rstate.no_optimizations()) { + return; + } + if (!contains_relevant_invoke(m_all_wrapped_apis, method)) { + return; + } + + TRACE(WP, 2, "optimize_method: %s", SHOW(method)); + cfg::CFGMutation mutation(cfg); + for (const auto& block : cfg.blocks()) { + auto env = intra_cp.get_entry_state_at(block); + // This block is unreachable + if (env.is_bottom()) { + continue; + } + auto last_insn = block->get_last_insn(); + auto ii = InstructionIterable(block); + for (auto it = ii.begin(); it != ii.end(); it++) { + auto cfg_it = block->to_cfg_instruction_iterator(it); + auto insn = cfg_it->insn; + + if (insn->has_method() && + m_all_wrapped_apis.count(insn->get_method()) > 0) { + TRACE(WP, 2, "Relevant invoke: %s", SHOW(insn)); + // Inline the wrapped constant value and change method ref. + auto srcs_size = insn->srcs_size(); + auto& reg_env = env.get_register_environment(); + bool changed_ref{false}; + for (size_t i = 0; i < srcs_size; i++) { + auto current_reg = insn->src(i); + TRACE(WP, 2, " Checking v%d", current_reg); + auto& value = reg_env.get(current_reg); + auto maybe_pair = extract_object_with_attr_value(value); + if (maybe_pair != boost::none) { + auto wrapper_type = maybe_pair->first; + auto literal = maybe_pair->second; + TRACE(WP, + 2, + " ** Instruction %s uses a known object with constant " + "value %" PRId64, + SHOW(insn), + literal); + + auto search = + m_type_to_spec.find(const_cast(wrapper_type)); + auto ref = insn->get_method(); + if (search != m_type_to_spec.end()) { + auto spec = search->second; + if (spec.allowed_invokes.count(ref)) { + auto unwrapped_ref = spec.allowed_invokes.at(ref); + auto is_wide = type::is_wide_type(spec.primitive); + auto literal_reg = + is_wide ? cfg.allocate_wide_temp() : cfg.allocate_temp(); + auto const_insn = (new IRInstruction(is_wide ? OPCODE_CONST_WIDE + : OPCODE_CONST)) + ->set_literal(literal) + ->set_dest(literal_reg); + + mutation.insert_before(cfg_it, {const_insn}); + increment_consts(); + insn->set_src(i, literal_reg); + if (!changed_ref) { + if (needs_cast(type_system, ref, unwrapped_ref)) { + auto to_type = unwrapped_ref->get_class(); + auto opcode = is_interface(type_class(to_type)) + ? OPCODE_INVOKE_INTERFACE + : OPCODE_INVOKE_VIRTUAL; + auto obj_reg = cfg.allocate_temp(); + auto cast = (new IRInstruction(OPCODE_CHECK_CAST)) + ->set_type(to_type) + ->set_src(0, insn->src(0)); + auto move_pseudo = + (new IRInstruction(IOPCODE_MOVE_RESULT_PSEUDO_OBJECT)) + ->set_dest(obj_reg); + insn->set_method(unwrapped_ref); + insn->set_opcode(opcode); + insn->set_src(0, obj_reg); + mutation.insert_before(cfg_it, {cast, move_pseudo}); + increment_casts(); + } else { + insn->set_method(unwrapped_ref); + } + changed_ref = true; + } + } + } + } + } + } + intra_cp.analyze_instruction(insn, &env, insn == last_insn->insn); + } + } + mutation.flush(); +} + +bool is_wrapped_api(const DexMethodRef* ref) { + auto wp_instance = get_instance(); + if (wp_instance == nullptr) { + return false; + } + return wp_instance->is_wrapped_api(ref); +} + +void optimize_method(const TypeSystem& type_system, + const cp::intraprocedural::FixpointIterator& intra_cp, + const cp::WholeProgramState& wps, + DexMethod* method, + cfg::ControlFlowGraph& cfg) { + auto wp_instance = get_instance(); + if (wp_instance == nullptr) { + return; + } + wp_instance->optimize_method(type_system, intra_cp, wps, method, cfg); +} +} // namespace wrapped_primitives diff --git a/service/wrapped-primitives/WrappedPrimitives.h b/service/wrapped-primitives/WrappedPrimitives.h new file mode 100644 index 00000000000..440e9fbec9f --- /dev/null +++ b/service/wrapped-primitives/WrappedPrimitives.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "ConstantEnvironment.h" +#include "ConstantPropagationAnalysis.h" +#include "ConstantPropagationState.h" +#include "ConstantPropagationWholeProgramState.h" +#include "ControlFlow.h" +#include "DexClass.h" +#include "GlobalConfig.h" +#include "PassManager.h" +#include "Trace.h" +#include "TypeSystem.h" + +namespace wrapped_primitives { +// A config driven spec describing wrapper classes to look for, each of which is +// asserted to have 1 final field of some primitive type. Beyond this, +// assumptions also include: +// 1) A constructor taking 1 argument which is the primitive it wraps. +// 2) Wrapper class extends java.lang.Object and does not implement interfaces. +// +// Wrapper class instances that can effectively be "unboxed" by this pass must +// conform to a very narrow set of usages. Currently, supported uses are: +// - Wrapper class can be instantiated with a known constant (known means +// intraprocedural constant propagation can easily figure it out). +// - Wrapper class instances can be written to static final fields. +// - Wrapper class instances can be retrieved from static final fields. +// - Wrapper class instances can be an argument to a set of configured "allowed +// - invokes" i.e. method refs that they can be passed to. +// +// Finally, the input program must honor guarantees about the allowed method +// invocations. For the output program to type check properly, it must be +// explicitly listed for every allowed API taking the wrapper class, what is the +// corresponding primitive API that should be swapped in. It is up to the author +// of the input program to ensure that this works in practice, otherwise Redex +// is free to fail in whatever way it chooses (i.e. fail the build or optimize +// no wrapper types). +// +// EXAMPLE: +// "LFoo;.a:(LMyLong;)V" is an allowed invoke, the config should map this to +// something like "LFoo;.a:(J)V" which will also need to exist in the input +// program. This is the simplest form. If however, the allowed invoke maps to an +// API on a different type, say from an interface method to a method on the +// interface's underlying implenentor, check-cast instructions may need to be +// inserted to make this work. It's up to the program's authors to ensure this +// ends up as a working app (and we may fail the build otherwise, or insert +// casts that would fail at runtime if things are misconfigured). +struct Spec { + DexType* wrapper{nullptr}; + DexType* primitive{nullptr}; + std::map allowed_invokes; + + std::vector wrapper_type_constructors() { + auto cls = type_class(wrapper); + return cls->get_ctors(); + } +}; + +namespace cp = constant_propagation; +class WrappedPrimitives { + public: + explicit WrappedPrimitives(const std::vector& wrapper_specs) + : m_wrapper_specs(wrapper_specs) { + for (auto& spec : wrapper_specs) { + TRACE(WP, + 1, + "Will check for wrapper type %s with supported methods:", + SHOW(spec.wrapper)); + auto wrapper_cls = type_class(spec.wrapper); + always_assert(wrapper_cls != nullptr); + m_type_to_spec.emplace(spec.wrapper, spec); + for (auto&& [from, to] : spec.allowed_invokes) { + TRACE(WP, 1, " %s", SHOW(from)); + } + for (auto&& [from, to] : spec.allowed_invokes) { + m_all_wrapped_apis.emplace(from); + } + } + } + void mark_roots(); + void unmark_roots(); + void optimize_method(const TypeSystem& type_system, + const cp::intraprocedural::FixpointIterator& intra_cp, + const cp::WholeProgramState& wps, + DexMethod* method, + cfg::ControlFlowGraph& cfg); + // Stats + size_t consts_inserted() { return m_consts_inserted; } + size_t casts_inserted() { return m_casts_inserted; } + // Convenience methods; + bool is_wrapped_api(const DexMethodRef* ref) { + return m_all_wrapped_apis.count(ref) > 0; + } + + private: + void increment_consts(); + void increment_casts(); + std::vector m_wrapper_specs; + std::unordered_map m_type_to_spec; + std::unordered_set m_all_wrapped_apis; + // Config driven optimization will create inbound references to new methods. + // These methods need to not be deleted. + std::unordered_set m_marked_root_classes; + std::unordered_set m_marked_root_methods; + // Concurrent stats + std::atomic m_consts_inserted{0}; + std::atomic m_casts_inserted{0}; +}; + +// Users should be talking to the singleton which is set up to be operational +// across the pass list. +WrappedPrimitives* get_instance(); +void initialize(const std::vector& wrapper_specs); + +// Simple checks for other passes to see if state has been configured. +bool is_wrapped_api(const DexMethodRef* ref); + +// Simplified entry point of optimizing a method, if configured. +void optimize_method(const TypeSystem& type_system, + const cp::intraprocedural::FixpointIterator& intra_cp, + const cp::WholeProgramState& wps, + DexMethod* method, + cfg::ControlFlowGraph& cfg); +} // namespace wrapped_primitives diff --git a/test/instr/WrappedPrimitives.java b/test/instr/WrappedPrimitives.java index 093a5ea7f7c..4eef0d86bba 100644 --- a/test/instr/WrappedPrimitives.java +++ b/test/instr/WrappedPrimitives.java @@ -7,6 +7,8 @@ package com.facebook.redex; +import java.util.Arrays; + class Constants { public static final long ONE = 1L; public static final long TWO = 2L; @@ -24,6 +26,10 @@ class MyLong { public MyLong(long value) { this.value = value; } + + public static MyLong make(long value) { + return new MyLong(value); + } }; class Bad { @@ -31,7 +37,7 @@ static void escape(MyLong l) {} } class AllValues { - public static final MyLong L1 = new MyLong(Constants.ONE); + public static final MyLong L1 = MyLong.make(Constants.ONE); public static final MyLong L2 = new MyLong(Constants.TWO); public static final MyLong L3 = new MyLong(Constants.THREE); public static final MyLong L4 = new MyLong(Constants.FOUR); @@ -73,7 +79,12 @@ class MoreValues { } } -class Receiver { +interface Safe { + long getLong(MyLong l); + long peekLong(MyLong l); +} + +class Receiver implements Safe { public long getLong(MyLong l) { return l.value; } @@ -97,6 +108,9 @@ static void markFetched(MyLong l) { } public class WrappedPrimitives { + + private static final Object LOCK = new Object(); + public static long[] run() { long[] results = new long[8]; Receiver r = new Receiver(); @@ -112,4 +126,46 @@ public static long[] run() { results[7] = r.getLong(MoreValues.L9); return results; } + + public static long simple(Receiver r) { + return r.getLong(AllValues.L1); + } + + public static long simpleCast(Safe s) { + return s.getLong(AllValues.L1); + } + + // Another expected usage; Interface type is given, and will need a check-cast + // to underlying impl which has the unwrapped method. Will need to ensure that + // monitor-enter/exit instructions are properly balanced under such an + // insertion + public static synchronized long[] runMonitor(Safe s) { + long[] results = new long[1]; + long l; + synchronized (LOCK) { + l = s.getLong(AllValues.L1); + } + results[0] = l; + return results; + } + + public static synchronized long[] runAnother(Safe s) { + String tag = "X"; + long[] results = new long[1]; + try { + results[0] = s.getLong(AllValues.L1); + } catch (IllegalStateException e) { + android.util.Log.w(tag, e); + } + return results; + } + + public static long[] runWithInterface() { + Receiver r = new Receiver(); + long[] one = runMonitor(r); + long[] two = runAnother(r); + long[] result = Arrays.copyOf(one, one.length + two.length); + System.arraycopy(two, 0, result, one.length, two.length); + return result; + } } diff --git a/test/instr/WrappedPrimitivesTest.java b/test/instr/WrappedPrimitivesTest.java index d6c5c10f8d1..9cf6af05595 100644 --- a/test/instr/WrappedPrimitivesTest.java +++ b/test/instr/WrappedPrimitivesTest.java @@ -24,4 +24,12 @@ public void testTranformableLookup() { assertThat(results[6]).isEqualTo(8); assertThat(results[7]).isEqualTo(9); } + + @Test + public void testTranformableCastLookup() { + long[] results = com.facebook.redex.WrappedPrimitives.runWithInterface(); + assertThat(results.length).isEqualTo(2); + assertThat(results[0]).isEqualTo(1); + assertThat(results[1]).isEqualTo(1); + } } diff --git a/test/instr/WrappedPrimitivesTestVerify.cpp b/test/instr/WrappedPrimitivesTestVerify.cpp index 76f195bd0e4..21061daac3b 100644 --- a/test/instr/WrappedPrimitivesTestVerify.cpp +++ b/test/instr/WrappedPrimitivesTestVerify.cpp @@ -13,51 +13,67 @@ using namespace testing; namespace { -std::set SUPPORTED_FIELDS{"L1", "L4", "L8"}; -std::set UNSUPPORTED_FIELDS{"L2", "L3", "L5", "L6", "L7"}; -} // namespace +void dump_method(DexMethod* method) { + method->balloon(); + method->get_code()->build_cfg(); + auto& cfg = method->get_code()->cfg(); + std::cout << show(method) << " " << show(cfg) << std::endl; +} -TEST_F(PreVerify, VerifyBaseState) { - auto wrapped_cls = find_class_named(classes, "Lcom/facebook/redex/MyLong;"); - auto wrapped_type = wrapped_cls->get_type(); - auto cls = find_class_named(classes, "Lcom/facebook/redex/AllValues;"); - - std::vector all_fields; - all_fields.insert(all_fields.end(), SUPPORTED_FIELDS.begin(), - SUPPORTED_FIELDS.end()); - all_fields.insert(all_fields.end(), UNSUPPORTED_FIELDS.begin(), - UNSUPPORTED_FIELDS.end()); - for (const auto& name : all_fields) { - auto f = find_sfield_named(*cls, name.c_str()); - EXPECT_NE(f, nullptr) << "Did not find field " << name; - EXPECT_EQ(f->get_type(), wrapped_type); +std::string stringify_for_comparision(DexMethod* method) { + method->balloon(); + // Remove positions to make asserts easier to write with IRAssembler. + auto code = method->get_code(); + for (auto it = code->begin(); it != code->end();) { + if (it->type == MFLOW_POSITION) { + it = code->erase_and_dispose(it); + } else { + it++; + } } + return assembler::to_string(code); } +} // namespace TEST_F(PostVerify, VerifyTransform) { - auto wrapped_cls = find_class_named(classes, "Lcom/facebook/redex/MyLong;"); - auto wrapped_type = wrapped_cls->get_type(); - auto primitive_long = type::_long(); - auto cls = find_class_named(classes, "Lcom/facebook/redex/AllValues;"); - for (const auto& name : SUPPORTED_FIELDS) { - auto f = find_sfield_named(*cls, name.c_str()); - EXPECT_NE(f, nullptr) << "Did not find field " << name; - EXPECT_EQ(f->get_type(), primitive_long) - << "Field " << SHOW(f) << " should be unboxed!"; + auto usage_cls = + find_class_named(classes, "Lcom/facebook/redex/WrappedPrimitives;"); + + // Simple unboxing. + { + auto simple = find_method_named(*usage_cls, "simple"); + auto simple_str = stringify_for_comparision(simple); + auto expected = assembler::ircode_from_string(R"(( + (load-param-object v2) + (sget-object "Lcom/facebook/redex/AllValues;.L1:Lcom/facebook/redex/MyLong;") + (move-result-pseudo-object v0) + (const-wide v0 1) + (invoke-virtual (v2 v0) "Lcom/facebook/redex/Receiver;.getLong:(J)J") + (move-result-wide v0) + (return-wide v0) + ))"); + EXPECT_EQ(simple_str, assembler::to_string(expected.get())); } - for (const auto& name : UNSUPPORTED_FIELDS) { - auto f = find_sfield_named(*cls, name.c_str()); - EXPECT_NE(f, nullptr) << "Did not find field " << name; - EXPECT_EQ(f->get_type(), wrapped_type) - << "Field " << SHOW(f) << " should be unchanged!"; + + // Insertion of a cast to the underlying unwrapped API. + { + auto simple_cast = find_method_named(*usage_cls, "simpleCast"); + auto simple_cast_str = stringify_for_comparision(simple_cast); + auto expected = assembler::ircode_from_string(R"(( + (load-param-object v2) + (sget-object "Lcom/facebook/redex/AllValues;.L1:Lcom/facebook/redex/MyLong;") + (move-result-pseudo-object v0) + (const-wide v0 1) + (check-cast v2 "Lcom/facebook/redex/Receiver;") + (move-result-pseudo-object v2) + (invoke-virtual (v2 v0) "Lcom/facebook/redex/Receiver;.getLong:(J)J") + (move-result-wide v0) + (return-wide v0) + ))"); + EXPECT_EQ(simple_cast_str, assembler::to_string(expected.get())); } - auto usage_cls = - find_class_named(classes, "Lcom/facebook/redex/WrappedPrimitives;"); - auto run = find_method_named(*usage_cls, "run"); - EXPECT_NE(run, nullptr); - run->balloon(); - run->get_code()->build_cfg(); - auto& cfg = run->get_code()->cfg(); - std::cout << show(cfg) << std::endl; + // Just for convenience, dump some methods as a much more readable CFG form. + dump_method(find_method_named(*usage_cls, "run")); + dump_method(find_method_named(*usage_cls, "runMonitor")); } diff --git a/test/instr/wrappedprimitives.config b/test/instr/wrappedprimitives.config index 2411a2c3be7..94079f210b8 100644 --- a/test/instr/wrappedprimitives.config +++ b/test/instr/wrappedprimitives.config @@ -9,13 +9,27 @@ }, { "Lcom/facebook/redex/Receiver;.peekLong:(Lcom/facebook/redex/MyLong;)J": "Lcom/facebook/redex/Receiver;.peekLong:(J)J" + }, + { + "Lcom/facebook/redex/Safe;.getLong:(Lcom/facebook/redex/MyLong;)J": "Lcom/facebook/redex/Receiver;.getLong:(J)J" + }, + { + "Lcom/facebook/redex/Safe;.peekLong:(Lcom/facebook/redex/MyLong;)J": "Lcom/facebook/redex/Receiver;.peekLong:(J)J" } ] } ] }, + "InterproceduralConstantPropagationPass": { + "create_runtime_asserts": false, + "include_virtuals": true, + "max_heap_analysis_iterations": 3, + "replace_moves_with_consts": true, + "use_multiple_callee_callgraph": true + }, "redex" : { "passes" : [ + "InterproceduralConstantPropagationPass", "WrappedPrimitivesPass", "RegAllocPass" ]