diff --git a/nums/core/application_manager.py b/nums/core/application_manager.py index 62050f68..977844b1 100644 --- a/nums/core/application_manager.py +++ b/nums/core/application_manager.py @@ -16,11 +16,12 @@ import logging import sys - +import time from nums.core import settings from nums.core.systems.filesystem import FileSystem from nums.core.systems import numpy_compute from nums.core.systems.systems import System, SerialSystem, RaySystem +from nums.core.systems.gpu_systems import CupyParallelSystem from nums.core.systems.schedulers import RayScheduler, TaskScheduler, BlockCyclicScheduler from nums.core.array.application import ArrayApplication @@ -71,6 +72,11 @@ def create(): use_head=settings.use_head) system: System = RaySystem(compute_module=compute_module, scheduler=scheduler) + elif system_name == "cupy-parallel": + system = CupyParallelSystem() + system.optimizer = settings.optimizer + system.num_gpus = settings.num_gpus + system.cluster_shape = (settings.num_gpus, 1) else: raise Exception() system.init() @@ -82,10 +88,16 @@ def destroy(): if _instance is None: return # This will shutdown ray if ray was started by NumS. - _instance.system.shutdown() - del _instance - _instance = None + system = _instance.system + del _instance.one_half + del _instance.two + del _instance.one + del _instance.zero + del _instance + system.shutdown() + # _instance.system.shutdown() + def configure_logging(): root = logging.getLogger() diff --git a/nums/core/array/application.py b/nums/core/array/application.py index 9a98a6aa..2ec0dd8e 100644 --- a/nums/core/array/application.py +++ b/nums/core/array/application.py @@ -23,12 +23,12 @@ from nums.core.storage.storage import ArrayGrid, StoredArray, StoredArrayS3 # TODO(hme): Remove dependence on specific system and scheduler implementations. 
from nums.core.systems.systems import System, RaySystem, SerialSystem +from nums.core.systems.gpu_systems import CupyParallelSystem from nums.core.systems.schedulers import BlockCyclicScheduler from nums.core.systems import utils as systems_utils from nums.core.systems.filesystem import FileSystem from nums.core.array.random import NumsRandomState - # pylint: disable = too-many-lines @@ -51,6 +51,9 @@ def num_cores_total(self): system: RaySystem = self.system nodes = system.nodes() num_cores = sum(map(lambda n: n["Resources"]["CPU"], nodes)) + elif isinstance(self.system, CupyParallelSystem): + system: CupyParallelSystem = self.system + num_cores = system.num_gpus else: assert isinstance(self.system, SerialSystem) num_cores = systems_utils.get_num_cores() @@ -93,6 +96,8 @@ def compute_block_shape(self, and isinstance(self.system.scheduler, BlockCyclicScheduler): # This configuration is the default. cluster_shape = self.system.scheduler.cluster_shape + elif isinstance(self.system, CupyParallelSystem): + cluster_shape = self.system.cluster_shape else: assert isinstance(self.system, SerialSystem) cluster_shape = (1, 1) diff --git a/nums/core/array/blockarray.py b/nums/core/array/blockarray.py index 5243f7ad..5a977e65 100644 --- a/nums/core/array/blockarray.py +++ b/nums/core/array/blockarray.py @@ -48,7 +48,7 @@ def from_scalar(cls, val, system): if isinstance(val, int): dtype = np.int elif isinstance(val, float): - dtype = np.float + dtype = np.float32 else: assert isinstance(val, (np.int32, np.int64, np.float32, np.float64)) dtype = None @@ -122,7 +122,6 @@ def touch(self): for grid_entry in self.grid.get_entry_iterator(): block: Block = self.blocks[grid_entry] oids.append(self.system.touch(block.oid, syskwargs=block.syskwargs())) - self.system.get(oids) return self def reshape(self, *shape, **kwargs): @@ -714,7 +713,7 @@ def __rpow__(self, other): return other ** self def __neg__(self): - return -1 * self + return -1.0 * self def __pos__(self): return self diff 
--git a/nums/core/optimizer/cluster_sim.py b/nums/core/optimizer/cluster_sim.py index d3117a3d..8a899099 100644 --- a/nums/core/optimizer/cluster_sim.py +++ b/nums/core/optimizer/cluster_sim.py @@ -52,28 +52,14 @@ def get_cluster_node_ids(self): def get_cluster_entry_iterator(self): return itertools.product(*map(range, self.cluster_shape)) - def get_cluster_entry(self, grid_entry): - cluster_entry = [] - num_grid_entry_axes = len(grid_entry) - num_cluster_axes = len(self.cluster_shape) - if num_grid_entry_axes <= num_cluster_axes: - # When array has fewer or equal # of axes than cluster. - for cluster_axis in range(num_cluster_axes): - if cluster_axis < num_grid_entry_axes: - cluster_dim = self.cluster_shape[cluster_axis] - grid_entry_dim = grid_entry[cluster_axis] - cluster_entry.append(grid_entry_dim % cluster_dim) - else: - cluster_entry.append(0) - elif num_grid_entry_axes > num_cluster_axes: - # When array has more axes then cluster. - for cluster_axis in range(num_cluster_axes): - cluster_dim = self.cluster_shape[cluster_axis] - grid_entry_dim = grid_entry[cluster_axis] - cluster_entry.append(grid_entry_dim % cluster_dim) - # Ignore trailing axes, as these are "cycled" to 0 by assuming - # the dimension of those cluster axes is 1. - return tuple(cluster_entry) + def get_cluster_entry(self, grid_entry, grid_shape): + ret = [0] + for i in range(len(grid_entry)): + dim = 1 if i == len(grid_entry) - 1 else grid_shape[i+1] + ret[0] = (ret[0] + grid_entry[i]) * dim + ret[0] = ret[0] % self.system.num_gpus + ret.append(0) + return tuple(ret) # Block Ops. 
diff --git a/nums/core/optimizer/comp_graph.py b/nums/core/optimizer/comp_graph.py index f4bc6b92..5551e70f 100644 --- a/nums/core/optimizer/comp_graph.py +++ b/nums/core/optimizer/comp_graph.py @@ -750,7 +750,7 @@ def graphs_from_ba(ba: BlockArrayBase, cluster_state: ClusterState, copy_on_op) for grid_entry in ba.grid.get_entry_iterator(): block: Block = ba.blocks[grid_entry] # Allocate the block to the node on which it's created. - node_id = cluster_state.get_cluster_entry(block.true_grid_entry()) + node_id = cluster_state.get_cluster_entry(block.true_grid_entry(), ba.grid.grid_shape) cluster_state.add_block(block, node_ids=[node_id]) cluster_state.init_mem_load(node_id, block.id) diff --git a/nums/core/optimizer/tree_search.py b/nums/core/optimizer/tree_search.py index 406b8a33..045a3396 100644 --- a/nums/core/optimizer/tree_search.py +++ b/nums/core/optimizer/tree_search.py @@ -55,17 +55,17 @@ def get_bc_action(self, tnode: TreeNode): # This is hacky, but no good way to do it w/ current abstractions. 
if isinstance(tnode, BinaryOp): grid_entry = self.get_tnode_grid_entry(tnode) - node_id = self.arr.cluster_state.get_cluster_entry(grid_entry) + node_id = self.arr.cluster_state.get_cluster_entry(grid_entry, self.arr.grid.grid_shape) actions = [(tnode.tree_node_id, {"node_id": node_id})] elif isinstance(tnode, ReductionOp): leaf_ids = tuple(tnode.leafs_dict.keys())[:2] grid_entry = self.get_tnode_grid_entry(tnode) - node_id = self.arr.cluster_state.get_cluster_entry(grid_entry) + node_id = self.arr.cluster_state.get_cluster_entry(grid_entry, self.arr.grid.grid_shape) actions = [(tnode.tree_node_id, {"node_id": node_id, "leaf_ids": leaf_ids})] elif isinstance(tnode, UnaryOp): grid_entry = self.get_tnode_grid_entry(tnode) - node_id = self.arr.cluster_state.get_cluster_entry(grid_entry) + node_id = self.arr.cluster_state.get_cluster_entry(grid_entry, self.arr.grid.grid_shape) actions = [(tnode.tree_node_id, {"node_id": node_id})] else: raise Exception() @@ -120,9 +120,8 @@ def simulate_action(self, action): return self.objective(new_resources) def objective(self, resources): - max_axes = tuple(np.arange(1, len(self.arr.cluster_state.cluster_shape) + 1)) # Our simple objective. - return np.sum(np.max(resources, axis=max_axes)) + return np.sum(resources[1:]) def get_tnode_grid_entry(self, tnode: TreeNode): if tnode.parent is None: diff --git a/nums/core/settings.py b/nums/core/settings.py index 2c2a0c4e..fdb4f399 100644 --- a/nums/core/settings.py +++ b/nums/core/settings.py @@ -28,7 +28,13 @@ # System settings. -system_name = os.environ.get("NUMS_SYSTEM", "ray-cyclic") +# system_name = os.environ.get("NUMS_SYSTEM", "ray-cyclic") + +# Parallel system only uses the following three values +system_name = os.environ.get("NUMS_SYSTEM", "cupy-parallel") +num_gpus = 4 +optimizer = True + # TODO (hme): # - Make cluster shape an environment variable. Default depends on available resources. # - use_head => use_driver, and should be an environment variable. 
diff --git a/nums/core/systems/cupy_compute.py b/nums/core/systems/cupy_compute.py new file mode 100644 index 00000000..cca57f5c --- /dev/null +++ b/nums/core/systems/cupy_compute.py @@ -0,0 +1,287 @@ +# coding=utf-8 +# Copyright (C) 2020 NumS Development Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import cupy as cp +from numpy.random import PCG64 +from numpy.random import Generator +import scipy.linalg +import scipy.special + +from nums.core.storage.storage import ArrayGrid +from nums.core.systems.interfaces import ComputeImp, RNGInterface +from nums.core.settings import np_ufunc_map + + +def block_rng(seed, jump_index): + return Generator(PCG64(seed).jumped(jump_index)) + + +class RNG(RNGInterface): + # A naive approach to implementing a parallel random state is to simply + # increment the random seed given by the user, but this + # will lead to many collisions if the user is also incrementing the random seed, + # so a more principled approach is needed. + # In particular, our generator should work just as serial generators would + # if a block array is generated with a random seed of 0, and then + # a random seed of 1. + # See: https://numpy.org/doc/stable/reference/random/parallel.html + # The easiest parallel RNG approach for NumS is to jump the BitGenerator state. + # The way this works is as follows: + # The user provides a seed, which we use to seed all bit generators. 
+ # Whenever a new array is sampled, increment the jump index for each block in the array, + # and in the remote function, sample the block shape using a newly initialized + # random state with the user provided seed and the jump index created for the block. + # Blocks are sampled for new arrays by incrementing the jump index as before. + # This can be viewed simply as incrementing the jump index whenever a new block needs to be + # sampled, regardless of the array the block belongs to. + # A global random state is maintained, just like in numpy, so that + # a random state is not required to sample numbers. + # The seed can be set for the global random state, which + # re-instantiates the global random state. + + # One issue is the following: + # nrs1 = NPRandom(1337) + # nrs2 = NPRandom(1337) + # x1 = nrs1.sample(shape=(100,), block_shape=(10,)) + # x2 = nrs1.sample(shape=(100,), block_shape=(11,)) + # x1 != x2 + # x1 performs more jumps because it has more blocks. + # One way to remedy this is to always sample using the default internal block shape, + # and then reshape to the required block shape. + # This is of course sub-optimal -- we could alternatively jump according to a small + # block shape, and generate as many jumps needed to sample the required block shape, + # but this is tedious, and since we won't expose block_shape as a parameter in the final + # api, the proposed approach works fine as-is. + + def __init__(self, seed=None, jump_index=0): + # pylint: disable=no-member + if seed is None: + seed = random.getrandbits(128) + self.seed = seed + self.rng = PCG64(seed) + self.jump_index = jump_index + + def new_block_rng_params(self): + params = self.seed, self.jump_index + self.jump_index += 1 + return params + + +class ComputeCls(ComputeImp): + # pylint: disable=no-member, unused-variable + + # I/O operations. 
+ def touch(self, arr): + return isinstance(arr, cp.ndarray) + + def empty(self, grid_entry, grid_meta): + grid = ArrayGrid.from_meta(grid_meta) + block_shape = grid.get_block_shape(grid_entry) + return cp.empty(block_shape, dtype=grid.dtype) + + def new_block(self, op_name, grid_entry, grid_meta): + op_func = cp.__getattribute__(op_name) + grid = ArrayGrid.from_meta(grid_meta) + block_shape = grid.get_block_shape(grid_entry) + if op_name == "eye": + assert cp.all(cp.diff(grid_entry) == 0) + return op_func(*block_shape, dtype=grid.dtype) + else: + return op_func(block_shape, dtype=grid.dtype) + + def random_block(self, rng_params, rfunc_name, rfunc_args, shape, dtype): + rng: Generator = block_rng(*rng_params) + op_func = rng.__getattribute__(rfunc_name) + result = op_func(*rfunc_args).reshape(shape) + if rfunc_name not in ("random", "integers"): + # Only random and integer supports sampling of a specific type. + result = result.astype(dtype) + return cp.array(result) + + def permutation(self, rng_params, size): + rng: Generator = block_rng(*rng_params) + return rng.permutation(size) + + def create_block(self, *src_arrs, src_params, dst_params, dst_shape, dst_shape_bc): + result = cp.empty(shape=dst_shape, dtype=src_arrs[0].dtype) + assert len(src_params) == len(dst_params) + for i in range(len(src_params)): + src_arr: cp.ndarray = src_arrs[i] + src_sel, srcT = src_params[i] + if srcT: + src_arr = src_arr.T + dst_sel, dstT = dst_params[i] + if dst_shape_bc is not None: + result.reshape(dst_shape_bc)[dst_sel] = src_arr[src_sel] + else: + result[dst_sel] = src_arr[src_sel] + return result + + def update_block(self, dst_arr, *src_arrs, src_params, dst_params): + assert len(src_params) == len(dst_params) + # We need to copy here. If we modify this after a no-copy assignment + # of a block from array A to B, modifying B will modify the contents of A. 
+ dst_arr = dst_arr.copy() + _, dstT = dst_params[0] + if dstT: + dst_arr = dst_arr.T + for i in range(len(src_params)): + src_arr: cp.ndarray = src_arrs[i] + src_sel, src_shape_bc, srcT = src_params[i] + if srcT: + src_arr = src_arr.T + dst_sel, dstT = dst_params[i] + if src_shape_bc is not None: + dst_arr[dst_sel] = src_arr.reshape(src_shape_bc)[src_sel] + else: + dst_arr[dst_sel] = src_arr[src_sel] + return dst_arr + + def update_block_by_index(self, dst_arr, src_arr, index_pairs): + result = dst_arr.copy() + for dst_index, src_index in index_pairs: + result[tuple(dst_index)] = src_arr[tuple(src_index)] + return result + + def update_block_along_axis(self, dst_arr, src_arr, index_pairs, axis): + # Assume shape along axes != axis are of equal dim. + result = dst_arr.copy() + dst_sel = [slice(None, None)] * len(dst_arr.shape) + src_sel = [slice(None, None)] * len(src_arr.shape) + for dst_index, src_index in index_pairs: + dst_sel[axis] = dst_index + src_sel[axis] = src_index + result[tuple(dst_sel)] = src_arr[tuple(src_sel)] + return result + + def diag(self, arr): + return cp.diag(arr) + + def arange(self, start, stop, step, dtype): + return cp.arange(start, stop, step, dtype) + + def reduce_axis(self, op_name, arr, axis, keepdims, transposed): + op_func = cp.__getattribute__(op_name) + if transposed: + arr = arr.T + return op_func(arr, axis=axis, keepdims=keepdims) + + # This is essentially a map. 
+ def map_uop(self, op_name, arr, args, kwargs): + ufunc = cp.__getattribute__(op_name) + return ufunc(arr, *args, **kwargs) + + def where(self, arr, x, y, block_slice_tuples): + if x is None: + assert y is None + res = cp.where(arr) + for i, (start, stop) in enumerate(block_slice_tuples): + arr = res[i] + arr += start + else: + assert isinstance(x, cp.ndarray) and isinstance(y, cp.ndarray) + res = cp.where(arr, x, y) + shape = res[0].shape + res = list(res) + res.append(shape) + return tuple(res) + + def xlogy(self, arr_x, arr_y): + return scipy.special.xlogy(arr_x, arr_y) + + def astype(self, arr, dtype_str): + dtype = getattr(cp, dtype_str) + return arr.astype(dtype) + + def sum_reduce(self, *arrs): + from functools import reduce + return reduce(cp.add, arrs) + + def transpose(self, arr): + return arr.T + + def split(self, arr, indices_or_sections, axis, transposed): + if transposed: + arr = arr.T + return cp.split(arr, indices_or_sections, axis) + + def bop(self, op, a1, a2, a1_shape, a2_shape, a1_T, a2_T, axes): + if a1_T: + a1 = a1.T + if a2_T: + a2 = a2.T + if a1.shape != a1_shape: + a1 = a1.reshape(a1_shape) + if a2.shape != a2_shape: + a2 = a2.reshape(a2_shape) + + if op == "tensordot": + return cp.tensordot(a1, a2, axes=axes) + op = np_ufunc_map.get(op, op) + try: + ufunc = cp.__getattribute__(op) + except Exception as _: + ufunc = scipy.special.__getattribute__(op) + return ufunc(a1, a2) + + def qr(self, *arrays, mode="reduced", axis=None): + if len(arrays) > 1: + assert axis is not None + arr = cp.concatenate(arrays, axis=axis) + else: + arr = arrays[0] + return cp.linalg.qr(arr, mode=mode) + + def cholesky(self, arr): + return cp.linalg.cholesky(arr) + + def svd(self, arr): + u, sigma, vT = cp.linalg.svd(arr) + u = u[:sigma.shape[0]] + return u, sigma, vT + + def inv(self, arr): + return cp.linalg.inv(arr) + + # Boolean + + def allclose(self, a: cp.ndarray, b: cp.ndarray, rtol, atol): + return cp.allclose(a, b, rtol, atol) + + # Logic + + def 
logical_and(self, *bool_list): + return cp.all(bool_list) + + def arg_op(self, op_name, arr, block_slice, other_argoptima=None, other_optima=None): + if op_name == "argmin": + arr_argmin = cp.argmin(arr) + arr_min = arr[arr_argmin] + if other_optima is not None and other_optima < arr_min: + return other_argoptima, other_optima + return block_slice.start + arr_argmin, arr_min + elif op_name == "argmax": + arr_argmax = cp.argmax(arr) + arr_max = arr[arr_argmax] + if other_optima is not None and other_optima > arr_max: + return other_argoptima, other_optima + return block_slice.start + arr_argmax, arr_max + else: + raise Exception("Unsupported arg op.") + + def reshape(self, arr, shape): + return arr.reshape(shape) diff --git a/nums/core/systems/gpu_systems.py b/nums/core/systems/gpu_systems.py new file mode 100644 index 00000000..194a7a96 --- /dev/null +++ b/nums/core/systems/gpu_systems.py @@ -0,0 +1,266 @@ +import sys +import functools +import time +import itertools +import gc +from collections import defaultdict +from typing import Tuple +import numpy as np +import ray + +from nums.core.systems import numpy_compute +from nums.core.settings import np_ufunc_map +from nums.core.systems.interfaces import RNGInterface +from nums.core.systems.utils import extract_functions + + +def cupy_used_bytes(): + import cupy as cp + mempool = cp.get_default_memory_pool() + return mempool.used_bytes() + + +class BaseGPUSystem(object): + def __init__(self): + for name in ['random_block', 'new_block', 'update_block', 'create_block', + 'sum_reduce', 'map_uop', 'reshape', 'inv', 'empty', 'reduce_axis', + 'astype', 'bop']: + setattr(self, name, functools.partial(self.call_compute_interface, name)) + + def get_rng(self, seed) -> RNGInterface: + from nums.core.systems import numpy_compute + self.rng_cls = numpy_compute.RNG + return self.rng_cls(seed) + + def init(self): + pass + + def shutdown(self): + pass + + def register(self, name: str, func: callable, remote_params: dict = None): + 
pass + + def call_compute_interface(self, name, *args, **kwargs): + raise NotImplementedError + + +############################################################## +############ SerialSystem: Serial implementation ############# +############################################################## +class SerialSystem(BaseGPUSystem): + def __init__(self, compute_module): + # Init ComputeInterface + self.compute_imp = compute_module.ComputeCls() + super().__init__() + + def call_compute_interface(self, name, *args, **kwargs): + del kwargs['syskwargs'] + #if name in ['bop', 'map_uop']: + # print(f"SerialSystem::call compute {name} {args[0]}") + #else: + # print(f"SerialSystem::call compute {name}") + ret = getattr(self.compute_imp, name)(*args, **kwargs) + #print(f"SerialSystem::result {ret.shape} {cupy_used_bytes()/1e9} {ret.dtype}") + return ret + + +class NumpySerialSystem(SerialSystem): + def __init__(self, num_gpus): + super().__init__(numpy_compute) + + def put(self, x): + return x + + def get(self, x): + return x + + def touch(self, object_id, syskwargs): + return object_id + + +class CupySerialSystem(SerialSystem): + def __init__(self, num_gpus): + import cupy as cp + from nums.core.systems import cupy_compute + + self.cp = cp + super().__init__(cupy_compute) + + def put(self, x): + return self.cp.array(x) + + def get(self, x): + self.cp.cuda.Device(0).synchronize() + if isinstance(x, list): + return [a.get() for a in x] + else: + return x.get() + + def touch(self, object_id, syskwargs): + self.cp.cuda.Device(0).synchronize() + return object_id + + def shutdown(self): + mempool = self.cp.get_default_memory_pool() + mempool.free_all_blocks() + +############################################################## +########## ParallelSystem: Parallel implementation ########### +############################################################## +class CupySystemArrRef: + def __init__(self, cp_arr, system): + self.cp_arr = cp_arr + self.system = system + + def __del__(self): + 
self.system.delete(self.cp_arr) + + +class CupyParallelSystem(BaseGPUSystem): + def __init__(self, local_cache=True, immediate_gc=False): + import cupy as cp + from nums.core.systems import cupy_compute + + self.cp = cp + self.num_gpus = 1 + self.local_cache = local_cache + self.immediate_gc = immediate_gc + self.cluster_shape = (self.num_gpus, 1) + self.optimizer = True + self.compute_imp = cupy_compute.ComputeCls() + self.dist_dict = defaultdict(dict) # Dict[hash(array) -> Dict[actor_id -> array]] + super().__init__() + + def put(self, x): + with self.cp.cuda.Device(0): + ret = self.cp.array(x) + ret = self._register_new_array(ret, 0) + + for actor_id in range(1, self.num_gpus): + self._distribute_to(ret, actor_id) + + return CupySystemArrRef(ret, self) + + def get(self, x): + if isinstance(x, list): + return [a.cp_arr.get() for a in x] + else: + return x.cp_arr.get() + + def touch(self, x, syskwargs): + x.cp_arr.device.synchronize() + return x + + def get_cluster_entry(self, grid_entry, grid_shape): + ret = [0] + for i in range(len(grid_entry)): + dim = 1 if i == len(grid_entry) - 1 else grid_shape[i+1] + ret[0] = (ret[0] + grid_entry[i]) * dim + ret[0] = ret[0] % self.num_gpus + ret.append(0) + return tuple(ret) + + def call_with_options(self, name, args, kwargs, options): + dst_actor = options["dst_actor"] + # print(f"CupyParallelSystem::call compute {args} on {dst_actor}") + + args = [self._distribute_to(v.cp_arr, dst_actor) + if isinstance(v, CupySystemArrRef) else v for v in args] + kwargs = {k: self._distribute_to(v.cp_arr, dst_actor) + if isinstance(v, CupySystemArrRef) else v for k, v in kwargs.items()} + + with self.cp.cuda.Device(dst_actor): + # print(f"CupyParallelSystem::call args {args} kwargs {kwargs}") + ret = getattr(self.compute_imp, name)(*args, **kwargs) + + if self.immediate_gc: + self.dist_dict = defaultdict(dict) + else: + ret = self._register_new_array(ret, dst_actor) + return CupySystemArrRef(ret, self) + + def 
call_compute_interface(self, name, *args, **kwargs): + # Make device placement decisions + syskwargs = kwargs.pop('syskwargs') + grid_entry = syskwargs["grid_entry"] + grid_shape = syskwargs["grid_shape"] + + if self.optimizer: + cluster_entry: tuple = self.get_cluster_entry(grid_entry, grid_shape) + # print(f"CupyParallelSystem::call grid entry {grid_entry} and cluster entry {cluster_entry}") + dst_actor = cluster_entry[0] + else: + if name == 'bop': + dst_actor = None + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, CupySystemArrRef): + dst_actor = arg.cp_arr.data.device_id + break + else: + gid = get_flatten_id(grid_entry, grid_shape) + dst_actor = gid % self.num_gpus + + options = {} + options["dst_actor"] = dst_actor + return self.call_with_options(name, args, kwargs, options) + + def distribute_to(self, arr_ref, dst_actor): + return self._distribute_to(arr_ref.cp_arr, dst_actor) + + def _distribute_to(self, arr, dst_actor): + if self.local_cache: + arr_hash = self._get_array_hash(arr) + ret = self.dist_dict[arr_hash].get(dst_actor, None) + if ret is None: + if arr.data.device_id == dst_actor: + ret = arr + else: + # print(f"Copy {arr.shape} from {arr.data.device_id} to {dst_actor}") + with self.cp.cuda.Device(dst_actor): + ret = self.cp.asarray(arr) + self.dist_dict[arr_hash][dst_actor] = ret + else: + if arr.data.device_id == dst_actor: + ret = arr + else: + with self.cp.cuda.Device(dst_actor): + ret = self.cp.asarray(arr) + + return ret + + def _get_array_hash(self, arr): + return (arr.data.device_id, arr.data.mem, arr.data.ptr) + + def _register_new_array(self, arr, dst_actor): + if self.local_cache: + self.dist_dict[self._get_array_hash(arr)][dst_actor] = arr + return arr + else: + return arr + + def get_options(self, cluster_entry, cluster_shape): + node_entry = self.get_cluster_entry(cluster_entry, cluster_shape) + return { + "dst_actor": node_entry[0] + } + + def delete(self, arr): + if not self.immediate_gc: + if 
self.dist_dict is not None: + del self.dist_dict[self._get_array_hash(arr)] + + def shutdown(self): + self.dist_dict = None + mempool = self.cp.get_default_memory_pool() + mempool.free_all_blocks() + + +def get_flatten_id(grid_entry, grid_shape): + ret = 0 + for i in range(len(grid_entry)): + dim = 1 if i == len(grid_entry) - 1 else grid_shape[i+1] + ret = (ret + grid_entry[i]) * dim + + return ret + diff --git a/nums/experimental/benchmark_mlp_data.py b/nums/experimental/benchmark_mlp_data.py new file mode 100644 index 00000000..f6808b40 --- /dev/null +++ b/nums/experimental/benchmark_mlp_data.py @@ -0,0 +1,317 @@ +import argparse +import time + +import numpy as np + +from nums import numpy as nps +from nums.core.array.application import ArrayApplication +from nums.core.array.blockarray import BlockArray +from nums.core.optimizer.cluster_sim import ClusterState +from nums.core.optimizer.comp_graph import GraphArray +from nums.core.systems.gpu_systems import ( + NumpySerialSystem, + CupySerialSystem, + CupyParallelSystem, +) +from nums.core import application_manager as am +from nums.core import settings +from utils import benchmark_func, get_number_of_gpus +import lr_opt as opt + +random_seed = 1337 + + +def cupy_used_bytes(): + import cupy as cp + + mempool = cp.get_default_memory_pool() + return mempool.used_bytes() + + +def forward(app, X, W): + Z = opt.collapse_graph_array(app, X @ W) + return Z + + +def sigmoid(app, one, X): + return one / (one + app.exp(-X)) + + +def sigmoid_deriv(one, Z): + return Z * (one - Z) + + +def sigmoid_opt(app, X, one): + return opt.collapse_graph_array(app, one / (one + app.exp(-X))) + + +def sigmoid_deriv_opt(app, Z, one): + return opt.collapse_graph_array(app, Z * (one - Z)) + + +def one_step_fit_common(app, one, X, y, W_in_1, W_1_2, W_2_out): + LR = one + Z_1 = X @ W_in_1 + + S_1 = sigmoid(app, one, Z_1) + F_1 = sigmoid_deriv(one, Z_1).T + + Z_2 = S_1 @ W_1_2 + S_2 = sigmoid(app, one, Z_2) + F_2 = sigmoid_deriv(one, Z_2).T + + 
Z_out = S_2 @ W_2_out + F_out = sigmoid_deriv(one, Z_out).T + y_predict = sigmoid(app, one, Z_out) + + # --back propagation-- + D_out = F_out * (y_predict - y).T + D_2 = F_2 * (W_2_out @ D_out) + D_1 = F_1 * (W_1_2 @ D_2) + + W_in_1 = W_in_1 - LR * (D_1 @ X).T + W_1_2 = W_1_2 - LR * (D_2 @ S_1).T + W_2_out = W_2_out - LR * (D_out @ S_2).T + + return W_in_1, W_1_2, W_2_out + + +def one_step_fit_np(np, X, y, W_in_1, W_1_2, W_2_out): + rets = one_step_fit_common(np, 1, X, y, W_in_1, W_1_2, W_2_out) + + +def one_step_fit(app, X, y, W_in_1, W_1_2, W_2_out): + rets = one_step_fit_common(app, app.one, X, y, W_in_1, W_1_2, W_2_out) + + for x in rets: + x.touch() + + +def distribute_weights(app, W, cluster_state): + for node_id in cluster_state.get_cluster_node_ids(): + for grid_entry in W.grid.get_entry_iterator(): + from nums.core.array.base import Block + block: Block = W.blocks[grid_entry] + if node_id not in cluster_state.get_block_node_ids(block.id): + dst_actor = node_id[0] + app.system.distribute_to(block.oid, dst_actor) # copy for compute + cluster_state.commit_copy_block(block.id, node_id) # copy for optimizer + + +def one_step_fit_opt_data_parallel(app, X, y, W_in_1, W_1_2, W_2_out, num_gpus, verbose=False): + # --forward propagation-- + if verbose: + print("start forward propagation") + LR = app.one + cluster_state = ClusterState((num_gpus, 1), app.system) + one_ga: GraphArray = GraphArray.from_ba(app.one, cluster_state) + X_ga = GraphArray.from_ba(X, cluster_state) + y_ga = GraphArray.from_ba(y, cluster_state) + W_in_1_ga = GraphArray.from_ba(W_in_1, cluster_state) + W_1_2_ga = GraphArray.from_ba(W_1_2, cluster_state) + W_2_out_ga = GraphArray.from_ba(W_2_out, cluster_state) + + if verbose: + print(f"distribute weights") + # Distribute Weights + distribute_weights(app, W_in_1, cluster_state) + distribute_weights(app, W_1_2, cluster_state) + distribute_weights(app, W_2_out, cluster_state) + + Z_1_ga: GraphArray = forward(app, X_ga, W_in_1_ga) + S_1_ga: 
GraphArray = sigmoid_opt(app, Z_1_ga, one_ga) + F_1_ga: GraphArray = sigmoid_deriv_opt(app, Z_1_ga, one_ga) + + if verbose: + print("forward Z_2") + Z_2_ga: GraphArray = forward(app, S_1_ga, W_1_2_ga) + S_2_ga: GraphArray = sigmoid_opt(app, Z_2_ga, one_ga) + F_2_ga: GraphArray = sigmoid_deriv_opt(app, Z_2_ga, one_ga) + if verbose: + print("forward Z_out") + Z_out_ga: GraphArray = forward(app, S_2_ga, W_2_out_ga) + y_predict_ga: GraphArray = sigmoid_opt(app, Z_out_ga, one_ga) + F_out_ga: GraphArray = sigmoid_deriv_opt(app, Z_out_ga, one_ga) + + # --back propagation-- + D_out_ga = opt.collapse_graph_array(app, F_out_ga.T * (y_predict_ga - y_ga).T) + if verbose: + print("collapse D_2_ga") + D_2_ga = opt.collapse_graph_array(app, F_2_ga.T * (W_2_out_ga @ D_out_ga)) + if verbose: + print("collapse D_1_ga") + D_1_ga = opt.collapse_graph_array(app, F_1_ga.T * (W_1_2_ga @ D_2_ga)) + + if verbose: + print("collapse_graph_array dW_in_1_ga") + + dW_in_1_ga = opt.collapse_graph_array(app, (D_1_ga @ X_ga).T) + if verbose: + print("collapse_graph_array dW_1_2_ga") + dW_1_2_ga = opt.collapse_graph_array(app, (D_2_ga @ S_1_ga).T) + if verbose: + print("collapse_graph_array dW_2_out_ga") + dW_2_out_ga = opt.collapse_graph_array(app, (D_out_ga @ S_2_ga).T) + + dW_in_1_ga_ba: BlockArray = opt.compute_graph_array(app, dW_in_1_ga) + dW_1_2_ga_ba: BlockArray = opt.compute_graph_array(app, dW_1_2_ga) + dW_2_out_ga_ba: BlockArray = opt.compute_graph_array(app, dW_2_out_ga) + + W_in_1 = W_in_1 - LR * dW_in_1_ga_ba + W_1_2 = W_1_2 - LR * dW_1_2_ga_ba + W_2_out = W_2_out - LR * dW_2_out_ga_ba + + W_in_1.touch() + W_1_2.touch() + W_2_out.touch() + + +def np_init_weights(app, X, y, dtype): + dim_1 = 4096 # neurons in the first layer + dim_2 = 4096 # neurons in the second layer + + W_in_1 = app.random.normal(size=(X.shape[1], dim_1)).astype(dtype) + W_1_2 = app.random.normal(size=(dim_1, dim_2)).astype(dtype) + W_2_out = app.random.normal(size=(dim_2, y.shape[1])).astype(dtype) + return W_in_1, 
W_1_2, W_2_out + + +def data_init_weights(app: ArrayApplication, X, y, verbose=False): + dim_1 = 4096 # neurons in the first layer + dim_2 = 4096 # neurons in the second layer + + W_in_1 = app.random.normal(shape=(X.shape[1], dim_1), block_shape=(X.block_shape[1], dim_1), dtype=X.dtype) + W_1_2 = app.random.normal(shape=(dim_1, dim_2), block_shape=(dim_1, dim_2), dtype=X.dtype) + W_2_out = app.random.normal(shape=(dim_2, y.shape[1]), block_shape=(dim_2, y.block_shape[1]), + dtype=X.dtype) + if verbose: + print(f"W_in_1.shape {W_in_1.shape} W_in_1.block_shape {W_in_1.block_shape}") + print(f"W_1_2.shape {W_1_2.shape} W_1_2.block_shape {W_1_2.block_shape}") + print(f"W_2_out.shape {W_2_out.shape} W_2_out.block_shape {W_2_out.block_shape}") + return W_in_1, W_1_2, W_2_out + + +def np_sample(app, sample_size, feature, dtype): + X_train = app.random.normal(size=(sample_size, feature)).astype(dtype) + y_train = app.ones((sample_size, 1), dtype=dtype) + return X_train, y_train + + +def sample(app: ArrayApplication, sample_size, feature, num_gpus, dtype): + X_train = app.random.normal(shape=(sample_size, feature), block_shape=(sample_size // num_gpus, feature), + dtype=dtype) + y_train = app.ones(shape=(sample_size, 1), block_shape=(sample_size // num_gpus, 1), dtype=dtype) + return X_train, y_train + + +def benchmark_mlp_data_parallel(num_gpus, N_list, system_class_list, d=1000, optimizer=True, dtype=np.float32): + format_string = "%20s,%10s,%10s,%10s" + print(format_string % ("Library", "N", "Cost", "CV")) + + for N in N_list: + N = int(N) + + for system_class in system_class_list: + # try: + if True: + if system_class in ["Cupy", "Numpy"]: + name = system_class + import cupy as cp + + arr_lib = cp if system_class == "Cupy" else np + app = arr_lib + + X, y = np_sample(np, sample_size=N, feature=d, dtype=dtype) + W_in_1, W_1_2, W_2_out = np_init_weights(np, X, y, dtype=dtype) + + X = cp.asarray(X) + y = cp.asarray(y) + W_in_1 = cp.asarray(W_in_1) + W_1_2 = 
cp.asarray(W_1_2) + W_2_out = cp.asarray(W_2_out) + + cp.cuda.Device(0).synchronize() + + # Benchmark one step mlp + def func(): + tic = time.time() + one_step_fit_np(arr_lib, X, y, W_in_1, W_1_2, W_2_out) + cp.cuda.Device(0).synchronize() + toc = time.time() + return toc - tic, None + + costs = benchmark_func(func) + del (X, y, W_in_1, W_1_2, W_2_out) + else: + # Init system + name = system_class.__name__ + app = am.instance() + + # Make dataset + nps.random.seed(0) + X, y = sample(app, sample_size=N, feature=d, num_gpus=num_gpus, dtype=dtype) + W_in_1, W_1_2, W_2_out = data_init_weights(app, X, y, verbose=False) + + # Benchmark one step MLP + def func(): + tic = time.time() + if optimizer: + one_step_fit_opt_data_parallel(app, X, y, W_in_1, W_1_2, W_2_out, num_gpus, verbose=False) + else: + one_step_fit(app, X, y, W_in_1, W_1_2, W_2_out) + toc = time.time() + return toc - tic, None + + costs = benchmark_func(func) + + del (X, y, W_in_1, W_1_2, W_2_out) + am.destroy() + # except Exception: + else: + costs = [-1] + + log_str = format_string % ( + name, + "%d" % N, + "%.4f" % np.mean(costs), + "%.2f" % (np.std(costs) / np.mean(costs)), + ) + + print(log_str) + with open("result_mlp_data.csv", "a") as f: + f.write(log_str + "\n") + + +if __name__ == "__main__": + num_gpus = settings.num_gpus + optimizer = settings.optimizer + benchmark_mlp_data_parallel( + num_gpus, + N_list=[ + # 2000, + # 4000, + # 8000, + # 16000, + # 32000, + # 40000, + # 42000, + 44000, + # 0.5e6 / 4, + # 1e6 / 4, + # 2e6 / 4, + # 3e6 / 4, + # 5e6 / 4, + # 10e6 / 4, + # 20e6 / 4, + # 40e6 / 4, + # 80e6 / 4, + # 160e6 / 4, + # 200e6 / 4, + ], + system_class_list=[ + CupyParallelSystem, + "Cupy", + # "Numpy" + ], + optimizer=optimizer, + ) diff --git a/nums/experimental/benchmark_mlp_model.py b/nums/experimental/benchmark_mlp_model.py new file mode 100644 index 00000000..b612c50c --- /dev/null +++ b/nums/experimental/benchmark_mlp_model.py @@ -0,0 +1,257 @@ +import argparse +import time + +import 
def distribute_graph_array(app, G, cluster_state):
    """Copy every block of ``G`` to each cluster node that does not hold it yet.

    For each node id reported by ``cluster_state`` and each grid entry of
    ``G``, the block behind that entry is looked up in ``cluster_state``.
    When the node is not among the block's node ids, the block's object is
    sent to the node's actor (``node_id[0]``) through
    ``app.system.distribute_to`` and the copy is recorded with
    ``cluster_state.commit_copy_block``.

    Args:
        app: Application object; only ``app.system.distribute_to`` is used.
        G: Graph array providing ``grid.get_entry_iterator()`` and ``graphs``.
        cluster_state: Tracks which node ids hold which blocks.
    """
    for target_node in cluster_state.get_cluster_node_ids():
        for entry in G.grid.get_entry_iterator():
            leaf = G.graphs[entry]
            blk = cluster_state.get_block(leaf.block_id)
            holders = cluster_state.get_block_node_ids(blk.id)
            if target_node in holders:
                # This node already has the block; nothing to copy.
                continue
            actor = target_node[0]
            # Copy for compute: move the actual data to the destination actor.
            app.system.distribute_to(blk.oid, actor)
            # Copy for the optimizer: record the new placement in the cluster state.
            cluster_state.commit_copy_block(blk.id, target_node)
sigmoid_deriv_opt(app, Z_1_ga, one_ga) + if verbose: + print("forward Z_2_ga") + Z_2_ga: GraphArray = forward(app, S_1_ga, W_1_2_ga) + S_2_ga: GraphArray = sigmoid_opt(app, Z_2_ga, one_ga) + F_2_ga: GraphArray = sigmoid_deriv_opt(app, Z_2_ga, one_ga) + if verbose: + print("forward Z_out_ga") + Z_out_ga: GraphArray = forward(app, S_2_ga, W_2_out_ga) + if verbose: + print("forward y_predict_ga") + y_predict_ga: GraphArray = sigmoid_opt(app, Z_out_ga, one_ga) + if verbose: + print("forward F_out_ga") + F_out_ga: GraphArray = sigmoid_deriv_opt(app, Z_out_ga, one_ga) + + # --back propagation-- + if verbose: + print("collapse D_out_ga") + D_out_ga = opt.collapse_graph_array(app, F_out_ga.T * (y_predict_ga - y_ga).T) + if verbose: + print("collapse D_2_ga") + D_2_ga = opt.collapse_graph_array(app, F_2_ga.T * (W_2_out_ga @ D_out_ga)) + if verbose: + print("collapse D_1_ga") + D_1_ga = opt.collapse_graph_array(app, F_1_ga.T * (W_1_2_ga @ D_2_ga)) + distribute_graph_array(app, D_1_ga, cluster_state) + if verbose: + print("collapse_graph_array dW_in_1_ga") + dW_in_1_ga = opt.collapse_graph_array(app, (D_1_ga @ X_ga).T) + if verbose: + print("collapse_graph_array dW_1_2_ga") + dW_1_2_ga = opt.collapse_graph_array(app, (D_2_ga @ S_1_ga).T) + if verbose: + print("collapse_graph_array dW_2_out_ga") + dW_2_out_ga = opt.collapse_graph_array(app, (D_out_ga @ S_2_ga).T) + + dW_in_1_ga_ba: BlockArray = opt.compute_graph_array(app, dW_in_1_ga) + dW_1_2_ga_ba: BlockArray = opt.compute_graph_array(app, dW_1_2_ga) + dW_2_out_ga_ba: BlockArray = opt.compute_graph_array(app, dW_2_out_ga) + + if verbose: + print("update W_in_1") + W_in_1 = W_in_1 - LR * dW_in_1_ga_ba + if verbose: + print("update W_1_2") + W_1_2 = W_1_2 - LR * dW_1_2_ga_ba + if verbose: + print("update W_2_out") + W_2_out = W_2_out - LR * dW_2_out_ga_ba + + W_in_1.touch() + W_1_2.touch() + W_2_out.touch() + + +def np_init_weights(app, X, y, d2, dtype): + dim_1 = 4096 # neurons in the first layer + dim_2 = d2 # neurons in the 
def benchmark_mlp_model_parallel(num_gpus, N_list, system_class_list, d=140000, optimizer=True, dtype=np.float32):
    """Time one MLP training step per (N, system) pair and log the results.

    For every sample count in ``N_list`` and every entry of
    ``system_class_list`` the training step is timed through
    ``benchmark_func``. One formatted row (library, N, d_in, d_2, mean cost,
    coefficient of variation) is printed and appended to
    ``result_mlp_model.csv``. Exceptions raised by a run propagate to the
    caller.

    Args:
        num_gpus: Number of GPUs; forwarded to ``sample``,
            ``model_parallel_init_weights`` and
            ``one_step_fit_opt_model_parallel``.
        N_list: Sample counts; each entry is converted with ``int``.
        system_class_list: Each entry is either the string ``"Cupy"`` or
            ``"Numpy"`` (single-library baseline) or a system class whose
            ``__name__`` labels the NumS run.
        d: Number of input features.
        optimizer: If true, the NumS run uses
            ``one_step_fit_opt_model_parallel``; otherwise ``one_step_fit``.
        dtype: dtype of the generated data and weights.
    """
    format_string = "%20s,%10s,%10s,%10s,%10s,%10s"
    print(format_string % ("Library", "N", "d_in", "d_2", "Cost", "CV"))

    d2 = 20000  # width of the second hidden layer; identical for every run
    for N in N_list:
        N = int(N)  # entries may be floats such as 1e6 / 4
        for system_class in system_class_list:
            if system_class in ("Cupy", "Numpy"):
                name = system_class
                on_gpu = system_class == "Cupy"
                if on_gpu:
                    # Import lazily so the NumPy baseline runs without CuPy/CUDA.
                    import cupy as cp

                    arr_lib = cp
                else:
                    arr_lib = np
                # Kept from the original: expose ``inv`` directly on the module.
                # NOTE(review): not used in this function; presumably read by
                # one_step_fit_np -- confirm before removing.
                arr_lib.inv = arr_lib.linalg.inv

                # Data and weights are always generated on the host with NumPy.
                X, y = np_sample(np, sample_size=N, feature=d, dtype=dtype)
                W_in_1, W_1_2, W_2_out = np_init_weights(np, X, y, d2, dtype=dtype)

                if on_gpu:
                    # Only the CuPy baseline moves the arrays to the device. The
                    # NumPy baseline previously did this too, so it was not
                    # actually measuring NumPy.
                    X = cp.asarray(X)
                    y = cp.asarray(y)
                    W_in_1 = cp.asarray(W_in_1)
                    W_1_2 = cp.asarray(W_1_2)
                    W_2_out = cp.asarray(W_2_out)
                    cp.cuda.Device(0).synchronize()

                # Benchmark one step mlp
                def func():
                    tic = time.time()
                    one_step_fit_np(arr_lib, X, y, W_in_1, W_1_2, W_2_out)
                    if on_gpu:
                        # Wait for queued GPU work before stopping the clock.
                        cp.cuda.Device(0).synchronize()
                    toc = time.time()
                    return toc - tic, None

                costs = benchmark_func(func)
                del (X, y, W_in_1, W_1_2, W_2_out)
            else:
                # Init system. The class is only used for its name; the system
                # itself comes from the application manager.
                name = system_class.__name__
                app = am.instance()

                # Make dataset
                nps.random.seed(0)
                X, y = sample(app, sample_size=N, feature=d, num_gpus=num_gpus, dtype=dtype)
                W_in_1, W_1_2, W_2_out = model_parallel_init_weights(app, num_gpus, X, y, d2, verbose=False)

                # Benchmark one step MLP
                def func():
                    tic = time.time()
                    if optimizer:
                        one_step_fit_opt_model_parallel(app, X, y, W_in_1, W_1_2, W_2_out, num_gpus)
                    else:
                        one_step_fit(app, X, y, W_in_1, W_1_2, W_2_out)
                    toc = time.time()
                    return toc - tic, None

                costs = benchmark_func(func)

                del (X, y, app, W_in_1, W_1_2, W_2_out)

            mean_cost = np.mean(costs)
            log_str = format_string % (
                name,
                "%d" % N,
                "%d" % d,
                "%d" % d2,
                "%.4f" % mean_cost,
                # Coefficient of variation: standard deviation relative to the mean.
                "%.2f" % (np.std(costs) / mean_cost),
            )
            print(log_str)
            with open("result_mlp_model.csv", "a") as f:
                f.write(log_str + "\n")
# nums/experimental/utils.py -- helpers shared by the experimental benchmarks.
import gc
import subprocess


def check_block_integrity(arr):
    """Assert that every block of ``arr`` agrees with the array's grid.

    For each grid entry, the stored block must report that same entry, the
    slice tuples the grid computes for it, and the block shape the grid
    computes for it.

    Raises:
        AssertionError: If any block disagrees with the grid.
    """
    for grid_entry in arr.grid.get_entry_iterator():
        block = arr.blocks[grid_entry]
        assert block.grid_entry == grid_entry, "grid_entry mismatch at %s" % (grid_entry,)
        assert block.rect == arr.grid.get_slice_tuples(grid_entry), "rect mismatch at %s" % (grid_entry,)
        assert block.shape == arr.grid.get_block_shape(grid_entry), "shape mismatch at %s" % (grid_entry,)


def benchmark_func(func, repeat=2, warmup=1):
    """Run ``func`` repeatedly and collect the cost it reports.

    Args:
        func: Zero-argument callable returning a ``(cost, extra)`` pair; only
            ``cost`` is kept.
        repeat: Number of measured calls.
        warmup: Number of unmeasured calls made first.

    Returns:
        List with the ``cost`` of each measured call, in call order.
    """
    for _ in range(warmup):
        gc.collect()
        func()

    costs = []
    for _ in range(repeat):
        # Collect garbage before and after each measured call, as the
        # original did.
        gc.collect()
        cost, _ = func()
        costs.append(cost)
        gc.collect()

    return costs


def get_number_of_gpus():
    """Return the number of GPUs listed by ``nvidia-smi``.

    Returns 0 when ``nvidia-smi`` cannot be started or exits with an error,
    instead of counting the lines of an error message as GPUs.
    """
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
            check=False,
        )
    except OSError:
        # nvidia-smi is not installed or not executable on this machine.
        return 0
    if result.returncode != 0:
        return 0
    # One non-empty output line per GPU.
    return sum(1 for line in result.stdout.splitlines() if line.strip())


def cupy_used_bytes():
    """Return the bytes currently in use in CuPy's default memory pool."""
    # Imported lazily so this module can be used without CuPy installed.
    import cupy as cp

    mempool = cp.get_default_memory_pool()
    return mempool.used_bytes()


if __name__ == "__main__":
    print(f"Number of GPUS: {get_number_of_gpus()}")