Skip to content

Commit

Permalink
[CPU] Refactor memory control and allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky committed Oct 26, 2024
1 parent 0c07136 commit 28be812
Show file tree
Hide file tree
Showing 19 changed files with 364 additions and 265 deletions.
9 changes: 6 additions & 3 deletions src/plugins/intel_cpu/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "openvino/runtime/threading/cpu_streams_info.hpp"
#include "openvino/runtime/threading/cpu_message.hpp"
#include "utils/serialize.hpp"
#include "memory_control.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include <cstring>
Expand Down Expand Up @@ -52,7 +53,8 @@ CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
m_cfg{cfg},
m_name{model->get_name()},
m_loaded_from_cache(loaded_from_cache),
m_sub_memory_manager(sub_memory_manager) {
m_sub_memory_manager(sub_memory_manager),
m_networkMemoryControl(std::make_shared<NetworkMemoryControl>()) {
m_mutex = std::make_shared<std::mutex>();
const auto& core = m_plugin->get_core();
if (!core)
Expand Down Expand Up @@ -157,13 +159,14 @@ CompiledModel::GraphGuard::Lock CompiledModel::get_graph() const {
std::lock_guard<std::mutex> lock{*m_mutex.get()};
auto isQuantizedFlag = (m_cfg.lpTransformsMode == Config::On) &&
ov::pass::low_precision::LowPrecision::isFunctionQuantized(m_model);

ctx = std::make_shared<GraphContext>(m_cfg,
m_socketWeights[socketId],
isQuantizedFlag,
m_networkMemoryControl->createMemoryControlUnit(),
streamsExecutor,
m_sub_memory_manager);
}

const std::shared_ptr<const ov::Model> model = m_model;
graphLock._graph.CreateGraph(model, ctx);
} catch (...) {
Expand Down Expand Up @@ -346,7 +349,7 @@ void CompiledModel::release_memory() {
for (auto&& graph : m_graphs) {
GraphGuard::Lock graph_lock{graph};
auto ctx = graph_lock._graph.getGraphContext();
ctx->getNetworkMemoryControl()->releaseMemory();
m_networkMemoryControl->releaseMemory();
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/plugins/intel_cpu/src/compiled_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#pragma once

#include <memory>
#include <string>
#include <vector>

Expand All @@ -19,6 +20,8 @@
namespace ov {
namespace intel_cpu {

class NetworkMemoryControl;

class CompiledModel : public ov::ICompiledModel {
public:
typedef std::shared_ptr<CompiledModel> Ptr;
Expand Down Expand Up @@ -51,6 +54,10 @@ class CompiledModel : public ov::ICompiledModel {

void release_memory() override;

std::shared_ptr<NetworkMemoryControl> get_network_memory_control() const {
return m_networkMemoryControl;
}

private:
std::shared_ptr<ov::ISyncInferRequest> create_sync_infer_request() const override;
friend class SyncInferRequest;
Expand Down Expand Up @@ -91,6 +98,7 @@ class CompiledModel : public ov::ICompiledModel {

std::vector<std::shared_ptr<CompiledModel>> m_sub_compiled_models;
std::shared_ptr<SubMemoryManager> m_sub_memory_manager = nullptr;
std::shared_ptr<NetworkMemoryControl> m_networkMemoryControl = nullptr;
bool m_has_sub_compiled_models = false;
};

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/edge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ void Edge::validate() {
getChild();

if (status != Status::Allocated || !memoryPtr) {
OPENVINO_THROW("Error memory is not allocated!");
OPENVINO_THROW("Error memory is not allocated for edge: ", name());
}
status = Status::Validated;
}
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Edge {
}

std::string name() const;
const MemoryDesc& getDesc() const;

private:
std::weak_ptr<Node> parent;
Expand All @@ -99,7 +100,6 @@ class Edge {
PortDescBaseCPtr getInputPortDesc() const;
PortDescBaseCPtr getOutputPortDesc() const;

const MemoryDesc& getDesc() const;
bool enforceReorder();

void collectConsumers(std::vector<std::shared_ptr<Node>>& result) const;
Expand Down
Loading

0 comments on commit 28be812

Please sign in to comment.