Skip to content

Commit

Permalink
Merge pull request #86 from clEsperanto/add-wait-kernel-func
Browse files Browse the repository at this point in the history
Add-wait-kernel-func
  • Loading branch information
StRigaud authored Oct 17, 2024
2 parents e3ec186 + 804ac7d commit c212ef5
Show file tree
Hide file tree
Showing 10 changed files with 962 additions and 895 deletions.
10 changes: 6 additions & 4 deletions native/clesperantoj/include/clesperantoj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ class DeviceJ
std::string getName() const;
std::string getInfo() const;

void setWaitForKernelFinish(bool flag);

std::shared_ptr<cle::Device> get() const;

bool operator==(const DeviceJ& other) const;
bool operator==(const DeviceJ &other) const;
};

enum class DTypeJ
Expand Down Expand Up @@ -70,8 +72,8 @@ class ArrayJ
friend class MemoryJ;

protected:
void writeFrom( void *data, const size_t &origin_x, const size_t &origin_y, const size_t &origin_z, const size_t &width, const size_t &height, const size_t &depth) const;
void readTo( void *data, const size_t &origin_x, const size_t &origin_y, const size_t &origin_z, const size_t &width, const size_t &height, const size_t &depth) const;
void writeFrom(void *data, const size_t &origin_x, const size_t &origin_y, const size_t &origin_z, const size_t &width, const size_t &height, const size_t &depth) const;
void readTo(void *data, const size_t &origin_x, const size_t &origin_y, const size_t &origin_z, const size_t &width, const size_t &height, const size_t &depth) const;

public:
ArrayJ() = default;
Expand All @@ -87,7 +89,7 @@ class ArrayJ

std::shared_ptr<cle::Array> get() const;

static ArrayJ create( const size_t &width, const size_t &height, const size_t &depth, const size_t &dimension, const DTypeJ &data_type, const MTypeJ &memory_type, const DeviceJ &device);
static ArrayJ create(const size_t &width, const size_t &height, const size_t &depth, const size_t &dimension, const DTypeJ &data_type, const MTypeJ &memory_type, const DeviceJ &device);
DTypeJ dtype() const;
MTypeJ mtype() const;
DeviceJ device() const;
Expand Down
77 changes: 50 additions & 27 deletions native/clesperantoj/src/clesperantoj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ DeviceJ::DeviceJ(const std::shared_ptr<cle::Device> &device) : device_(device)
{
}

bool DeviceJ::operator==(const DeviceJ& other) const
bool DeviceJ::operator==(const DeviceJ &other) const
{
return (device_ == other.device_);
}
Expand All @@ -48,6 +48,11 @@ std::string DeviceJ::getInfo() const
return this->device_->getInfo();
}

void DeviceJ::setWaitForKernelFinish(bool flag)
{
this->device_->setWaitToFinish(flag);
}

std::shared_ptr<cle::Device> DeviceJ::get() const
{
return this->device_;
Expand Down Expand Up @@ -92,7 +97,6 @@ void ArrayJ::copyDataTo(ArrayJ &dst)
this->array_->copyTo(dst.get());
}


std::vector<std::string> UtilsJ::getKeys(const std::unordered_map<std::string, std::vector<float>> &map)
{
std::vector<std::string> keys;
Expand All @@ -115,51 +119,71 @@ inline cle::dType to_cle_dType(const DTypeJ &dtype)
{
switch (dtype)
{
case DTypeJ::INT8: return cle::dType::INT8;
case DTypeJ::UINT8: return cle::dType::UINT8;
case DTypeJ::INT16: return cle::dType::INT16;
case DTypeJ::UINT16: return cle::dType::UINT16;
case DTypeJ::INT32: return cle::dType::INT32;
case DTypeJ::UINT32: return cle::dType::UINT32;
case DTypeJ::FLOAT: return cle::dType::FLOAT;
case DTypeJ::UNKNOWN:
default: return cle::dType::UNKNOWN;
case DTypeJ::INT8:
return cle::dType::INT8;
case DTypeJ::UINT8:
return cle::dType::UINT8;
case DTypeJ::INT16:
return cle::dType::INT16;
case DTypeJ::UINT16:
return cle::dType::UINT16;
case DTypeJ::INT32:
return cle::dType::INT32;
case DTypeJ::UINT32:
return cle::dType::UINT32;
case DTypeJ::FLOAT:
return cle::dType::FLOAT;
case DTypeJ::UNKNOWN:
default:
return cle::dType::UNKNOWN;
}
}

inline DTypeJ from_cle_dType(const cle::dType &dtype)
{
switch (dtype)
{
case cle::dType::INT8: return DTypeJ::INT8;
case cle::dType::UINT8: return DTypeJ::UINT8;
case cle::dType::INT16: return DTypeJ::INT16;
case cle::dType::UINT16: return DTypeJ::UINT16;
case cle::dType::INT32: return DTypeJ::INT32;
case cle::dType::UINT32: return DTypeJ::UINT32;
case cle::dType::FLOAT: return DTypeJ::FLOAT;
// case cle::dType::UNKNOWN: // TODO: uncomment after https://github.com/clEsperanto/CLIc/pull/353 is merged, released, and we depend on that version
default: return DTypeJ::UNKNOWN;
case cle::dType::INT8:
return DTypeJ::INT8;
case cle::dType::UINT8:
return DTypeJ::UINT8;
case cle::dType::INT16:
return DTypeJ::INT16;
case cle::dType::UINT16:
return DTypeJ::UINT16;
case cle::dType::INT32:
return DTypeJ::INT32;
case cle::dType::UINT32:
return DTypeJ::UINT32;
case cle::dType::FLOAT:
return DTypeJ::FLOAT;
// case cle::dType::UNKNOWN: // TODO: uncomment after https://github.com/clEsperanto/CLIc/pull/353 is merged, released, and we depend on that version
default:
return DTypeJ::UNKNOWN;
}
}

inline cle::mType to_cle_mType(const MTypeJ &mtype)
{
switch (mtype)
{
case MTypeJ::IMAGE: return cle::mType::IMAGE;
case MTypeJ::BUFFER:
default: return cle::mType::BUFFER;
case MTypeJ::IMAGE:
return cle::mType::IMAGE;
case MTypeJ::BUFFER:
default:
return cle::mType::BUFFER;
}
}

inline MTypeJ from_cle_mType(const cle::mType &mtype)
{
switch (mtype)
{
case cle::mType::IMAGE: return MTypeJ::IMAGE;
case cle::mType::BUFFER:
default: return MTypeJ::BUFFER;
case cle::mType::IMAGE:
return MTypeJ::IMAGE;
case cle::mType::BUFFER:
default:
return MTypeJ::BUFFER;
}
}

Expand Down Expand Up @@ -258,4 +282,3 @@ void MemoryJ::writeFromInt(const ArrayJ &array, int *data, const size_t &origin_
{
array.writeFrom(static_cast<void *>(data), origin_x, origin_y, origin_z, width, height, depth);
}

97 changes: 63 additions & 34 deletions src/main/java/net/clesperanto/core/DeviceJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
import net.clesperanto._internals.jclic._StringVector;

/**
* Class to interact with the divide that is going to be used to do the operations
* Class to interact with the divide that is going to be used to do the
* operations
*/
public class DeviceJ {

Expand All @@ -60,12 +61,14 @@ private DeviceJ() {
* Constructor that initializes the wanted device
* IMPORTANT: Does not initialize the backend.
*
* TODO provide a better explanation of what deviceName is and what device type is
* TODO provide a better explanation of what deviceName is and what device type
* is
*
* @param deviceName
* the name of the device that wants to be initialized
* the name of the device that wants to be initialized
* @param deviceType
* the type that wants to be initialized. If any type works, the argument should be "all"
* the type that wants to be initialized. If any type works,
* the argument should be "all"
*/
private DeviceJ(String deviceName, String deviceType) {
Objects.requireNonNull(deviceName, "The device name cannot be null");
Expand All @@ -79,8 +82,10 @@ private DeviceJ(String deviceName, String deviceType) {
}

/**
* Get the first device available. By default this method tries to use an OpenCL backend.
* Get the first device available. By default this method tries to use an OpenCL
* backend.
* If not OpenCL backend is available the method will fail.
*
* @return the default device where Clesperanto operations can be done.
*/
public static DeviceJ getDefaultDevice() {
Expand All @@ -92,9 +97,11 @@ public static DeviceJ getDefaultDevice() {
* Get the first device available and select the wanted backend.
* If the wanted backend is not available, the method will fall back to OpenCL.
* And if OpenCL is not available either the method will fail.
*
* @param backend
* the type of backend that wants to be used. It should be either "cuda" or "opencl",
* if it is anything else it will be set to "opencl"
* the type of backend that wants to be used. It should be either
* "cuda" or "opencl",
* if it is anything else it will be set to "opencl"
* @return the default device where Clesperanto operations can be done.
*/
public static DeviceJ getDefaultDevice(String backend) {
Expand All @@ -103,15 +110,18 @@ public static DeviceJ getDefaultDevice(String backend) {
}

/**
* Get the wanted device by its name and device type. Initialize the device with openCL backend.
* Get the wanted device by its name and device type. Initialize the device with
* openCL backend.
* If not OpenCL backend is available the method will fail.
*
* TODO provide a better explanation of what deviceName is and what device type is
* TODO provide a better explanation of what deviceName is and what device type
* is
*
* @param deviceName
* the name of the device that wants to be initialized
* the name of the device that wants to be initialized
* @param deviceType
* the type that wants to be initialized. If any type works, the argument should be "all"
* the type that wants to be initialized. If any type works,
* the argument should be "all"
* @return the wanted device where Clesperanto operations can be done.
*/
public static DeviceJ getDeviceWithDefaultBackend(String deviceName, String deviceType) {
Expand All @@ -121,18 +131,22 @@ public static DeviceJ getDeviceWithDefaultBackend(String deviceName, String devi

/**
* Get the wanted device by its name and device type. Initialize the device with
* the wanted backend. If the backend is not available it will fallback to OpenCl backend.
* the wanted backend. If the backend is not available it will fallback to
* OpenCl backend.
* If OpenCL backend is not available the method will fail.
*
* TODO provide a better explanation of what deviceName is and what device type is
* TODO provide a better explanation of what deviceName is and what device type
* is
*
* @param deviceName
* the name of the device that wants to be initialized
* the name of the device that wants to be initialized
* @param deviceType
* the type that wants to be initialized. If any type works, the argument should be "all"
* the type that wants to be initialized. If any type works,
* the argument should be "all"
* @param backend
* the type of backend that wants to be used. It should be either "cuda" or "opencl",
* if it is anything else it will be set to "opencl"
* the type of backend that wants to be used. It should be
* either "cuda" or "opencl",
* if it is anything else it will be set to "opencl"
* @return the wanted device where Clesperanto operations can be done.
*/
public static DeviceJ getDevice(String deviceName, String deviceType, String backend) {
Expand All @@ -156,42 +170,53 @@ public String getInfo() {
return this._deviceJ.getInfo();
}

/**
*
* @return void
*/
public void setWaitForKernelFinish(boolean wait) {
this._deviceJ.setWaitForKernelFinish(wait);
}

/**
* TODO confirm if the devices are only GPUs or can be other hardware
* Method that returns the available devices (GPUs) on the computer.
*
* @return a list of the available devices in the computer
*/
public static List<String> getAvailableDevices(){
public static List<String> getAvailableDevices() {
_StringVector devices = _DeviceJ.getAvailableDevices();
List<String> devicesList = new ArrayList<String>();
for (int i = 0; i < devices.size(); i++) {
devicesList.add(devices.get(i));
}
}
return devicesList;
}

/**
* TODO confirm if the devices are only GPUs or can be other hardware
* Method that returns the available devices (GPUs) of the given {@code deviceType} on the computer.
* Using the {@code deviceType} "all" returns all the devices available, it is the same as using {@link #getAvailableDevices()}.
* Method that returns the available devices (GPUs) of the given
* {@code deviceType} on the computer.
* Using the {@code deviceType} "all" returns all the devices available, it is
* the same as using {@link #getAvailableDevices()}.
*
* @param deviceType the type of device to look for
* @return a list of the available devices in the computer of the specific type
*/
public static List<String> getAvailableDevices(String deviceType){
public static List<String> getAvailableDevices(String deviceType) {
Objects.requireNonNull(deviceType, "The device type cannot be null, if any device type works, use \"all\" or"
+ " use the method \"DeviceJ.getAvailableDevices()\"");
_StringVector devices = _DeviceJ.getAvailableDevices(deviceType);
List<String> devicesList = new ArrayList<String>();
for (int i = 0; i < devices.size(); i++) {
devicesList.add(devices.get(i));
}
}
return devicesList;
}

/**
* Return the backend that the device is using.
*
* @return the backend (opencl, cuda) that the device is using
* @throws RuntimeException if there is any error finding the backend
*/
Expand All @@ -200,20 +225,23 @@ public String getBackend() {

int ind = info.indexOf(")");

if (ind == -1) throw new RuntimeException("Unable to retrieve backend");
if (ind == -1)
throw new RuntimeException("Unable to retrieve backend");

return info.substring(1, ind).toLowerCase();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof DeviceJ)) return false;
if (this == o)
return true;
if (!(o instanceof DeviceJ))
return false;
DeviceJ deviceJ = (DeviceJ) o;
return _deviceJ.equals(deviceJ._deviceJ);
// NB: _DeviceJ.equals method is a native method overloading (not
// overriding) Object.equals() Therefore, Objects.equals(_deviceJ,
// deviceJ._deviceJ) will *not* work here!
// overriding) Object.equals() Therefore, Objects.equals(_deviceJ,
// deviceJ._deviceJ) will *not* work here!
}

@Override
Expand Down Expand Up @@ -243,10 +271,11 @@ public ArrayJ createArray(final DataType dataType, final MemoryType memoryType,
}

/**
*
* @return the raw object that is going to be sent to the native Clesperanto library. Without Java wrappers
*/
public _DeviceJ getRaw() {
return this._deviceJ;
}
*
* @return the raw object that is going to be sent to the native Clesperanto
* library. Without Java wrappers
*/
public _DeviceJ getRaw() {
return this._deviceJ;
}
}
Loading

0 comments on commit c212ef5

Please sign in to comment.