Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models actions from v1 to v2 protocol #18

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import com.relationalai.HttpError;
import com.relationalai.Json;

public class DeleteModel implements Runnable {
public class DeleteModels implements Runnable {
String database, engine, model, profile;

public void parseArgs(String[] args) {
Expand All @@ -42,7 +42,7 @@ public void run(String[] args) throws HttpError, InterruptedException, IOExcepti
parseArgs(args);
var cfg = Config.loadConfig("~/.rai/config", this.profile);
var client = new Client(cfg);
var rsp = client.deleteModel(database, engine, model);
Json.print(rsp, 4);
var rsp = client.deleteModels(database, engine, new String[] {model});
System.out.println(rsp);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;

import com.relationalai.Client;
import com.relationalai.Config;
import com.relationalai.HttpError;
import com.relationalai.Json;

public class LoadModel implements Runnable {
public class LoadModels implements Runnable {
String database, engine, filename, relation, profile;

// Returns the name of the file, without extension.
Expand Down Expand Up @@ -55,8 +58,11 @@ public void run(String[] args) throws HttpError, InterruptedException, IOExcepti
var cfg = Config.loadConfig("~/.rai/config", profile);
var client = new Client(cfg);
var name = sansext(filename);
var input = new FileInputStream(filename);
var rsp = client.loadModel(database, engine, name, input);
Json.print(rsp, 4);
var input = new String(new FileInputStream(filename).readAllBytes(), StandardCharsets.UTF_8);
var models = new HashMap<String, String>() {{
put(name, input);
}};
var rsp = client.loadModels(database, engine, models);
System.out.println(rsp);
}
}
148 changes: 85 additions & 63 deletions rai-sdk/src/main/java/com/relationalai/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -851,87 +851,109 @@ public Edb[] listEdbs(String database, String engine)

// Models

// Delete the named model.
public TransactionResult deleteModel(String database, String engine, String name)
// Delete the list of named models.
public TransactionAsyncResult deleteModels(String database, String engine, String[] names)
throws HttpError, InterruptedException, IOException {
var tx = new Transaction(this.region, database, engine, "OPEN");
var action = DbAction.makeDeleteModelAction(name);
var body = tx.payload(action);
var rsp = post(PATH_TRANSACTION, tx.queryParams(), body);
return Json.deserialize((String) rsp, TransactionResult.class);
var queries = new ArrayList<String>();
for (var name : names) {
queries.add(
String.format("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", name, name)
);
}

return execute(database, engine, String.join("\n", queries), false);
}

// Delete the list of named models.
public TransactionResult deleteModel(String database, String engine, String[] names)
public TransactionAsyncResult deleteModelsAsync(String database, String engine, String[] names)
throws HttpError, InterruptedException, IOException {
var tx = new Transaction(this.region, database, engine, "OPEN");
var actions = DbAction.makeDeleteModelsAction(names);
var body = tx.payload(actions);
var rsp = post(PATH_TRANSACTION, tx.queryParams(), body);
return Json.deserialize((String) rsp, TransactionResult.class);
var queries = new ArrayList<String>();
for (var name : names) {
queries.add(
String.format("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", name, name)
);
}

return executeAsync(database, engine, String.join("\n", queries), false);
}

// Return the named model.
public Model getModel(String database, String engine, String name)
throws HttpError, InterruptedException, IOException {
var models = listModels(database, engine);
for (var item : models) {
if (item.name.equals(name))
return item;
var outName = String.format("model_%d", new Random().nextInt(Integer.MAX_VALUE));
var query = String.format("def output:%s = rel:catalog:model[\"%s\"]", outName, name);

var resp = execute(database, engine, query, true);
var result = resp.results.stream().filter(
r -> r.relationId.equals(String.format("/:output/:%s/String", outName))
).findFirst().orElse(null);

if (result != null) {
return new Model(name, result.table.get(0).toString());
}

throw new HttpError(404);
}

// Load a model into the given database.
public TransactionResult loadModel(
String database, String engine, String name, InputStream model)
throws HttpError, InterruptedException, IOException {
var s = new String(model.readAllBytes());
return loadModel(database, engine, name, s);
// Load multiple models into the given database.
public TransactionAsyncResult loadModels(String database, String engine, Map<String, String> models) throws HttpError, IOException, InterruptedException {
var queries = new ArrayList<String>();
var queriesInputs = new HashMap<String, String>();
var randInt = new Random().nextInt(Integer.MAX_VALUE);

var index = 0;
for (var model : models.entrySet()) {
var inputName = String.format("input_%d_%d", randInt, index);
queries.add(
String.format("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", model.getKey(), model.getKey())
);
queries.add(
String.format("def insert:rel:catalog:model[\"%s\"] = %s", model.getKey(), inputName)
);
queriesInputs.put(inputName, model.getValue());
index++;
}
return execute(database, engine, String.join("\n", queries), false, queriesInputs);
}

public TransactionAsyncResult loadModelsAsync(String database, String engine, Map<String, String> models) throws HttpError, IOException, InterruptedException {
var queries = new ArrayList<String>();
var queriesInputs = new HashMap<String, String>();
var randInt = new Random().nextInt(Integer.MAX_VALUE);

var index = 0;
for (var model : models.entrySet()) {
var inputName = String.format("input_%d_%d", randInt, index);
queries.add(
String.format("def delete:rel:catalog:model[\"%s\"] = rel:catalog:model[\"%s\"]", model.getKey(), model.getKey())
);
queries.add(
String.format("def insert:rel:catalog:model[\"%s\"] = %s", model.getKey(), inputName)
);
queriesInputs.put(inputName, model.getValue());
index++;
}
return executeAsync(database, engine, String.join("\n", queries), false, queriesInputs);
}

public TransactionResult loadModel(
String database, String engine, String name, String model)
throws HttpError, InterruptedException, IOException {
var tx = new Transaction(this.region, database, engine, "OPEN", false);
var action = DbAction.makeInstallAction(name, model);
var data = tx.payload(action);
var rsp = post(PATH_TRANSACTION, tx.queryParams(), data);
return Json.deserialize((String) rsp, TransactionResult.class);
}
// Returns the list of models names installed in the given
// database.

// Load multiple models into the given database.
public TransactionResult loadModels(
String database, String engine, Map<String, String> models)
throws HttpError, InterruptedException, IOException {
var tx = new Transaction(this.region, database, engine, "OPEN", false);
var actions = DbAction.makeInstallAction(models);
var data = tx.payload(actions);
var rsp = post(PATH_TRANSACTION, tx.queryParams(), data);
return Json.deserialize((String) rsp, TransactionResult.class);
}
public List<String> listModels(String database, String engine) throws HttpError, IOException, InterruptedException {
var outName = String.format("models_%d", new Random().nextInt(Integer.MAX_VALUE));
var query = String.format("def output:%s[name] = rel:catalog:model(name, _)", outName);

// Returns the list of names of models installed in the given database.
public String[] listModelNames(String database, String engine)
throws HttpError, InterruptedException, IOException {
var models = listModels(database, engine);
String[] result = new String[models.length];
for (var i = 0; i < models.length; ++i)
result[i] = models[i].name;
return result;
}
var resp = execute(database, engine, query, true);
var result = resp.results.stream().filter(
r -> r.relationId.equals(String.format("/:output/:%s/String", outName))
).findFirst().orElse(null);

// Returns the list of models (including source) installed in the given
// database.
public Model[] listModels(String database, String engine)
throws HttpError, InterruptedException, IOException {
var tx = new Transaction(this.region, database, engine, "OPEN", true);
var body = tx.payload(DbAction.makeListModelsAction());
var rsp = post(PATH_TRANSACTION, tx.queryParams(), body);
var actions = Json.deserialize((String) rsp, ListModelsResponse.class).actions;
if (actions.length == 0)
return new Model[] {};
return actions[0].result.models;
if (result != null) {
return result.table.stream()
.map(elem -> elem.toString())
.collect(Collectors.toList());
}

return new ArrayList<String>();
}

// Data loading
Expand Down
34 changes: 14 additions & 20 deletions rai-sdk/src/test/java/com/relationalai/DatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -78,15 +80,11 @@ void testDatabase() throws HttpError, InterruptedException, IOException {
var edb = find(edbs, item -> item.name.equals("rel"));
assertNotNull(edb);

var modelNames = client.listModelNames(databaseName, engineName);
var name = find(modelNames, item -> item.equals("stdlib"));
var modelNames = client.listModels(databaseName, engineName);
var name = modelNames.stream().filter(n -> n.equals("rel/alglib")).findFirst().orElse(null);
assertNotNull(name);

var models = client.listModels(databaseName, engineName);
var model = find(models, m -> m.name.equals("stdlib"));
assertNotNull(model);

model = client.getModel(databaseName, engineName, "stdlib");
var model = client.getModel(databaseName, engineName, "rel/alglib");
assertNotNull(model);
assertTrue(model.value.length() > 0);

Expand All @@ -106,8 +104,9 @@ void testDatabase() throws HttpError, InterruptedException, IOException {
assertNull(database);
}

static final String testModel =
"def R = \"hello\", \"world\"";
static final Map<String, String> testModel = new HashMap<String, String>(){{
put("test_model", "def R = \"hello\", \"world\"");
}};

static final String testJson = "{" +
"\"name\":\"Amira\",\n" +
Expand Down Expand Up @@ -137,10 +136,9 @@ void testDatabase() throws HttpError, InterruptedException, IOException {
assertEquals(0, loadRsp.output.length);
assertEquals(0, loadRsp.problems.length);

loadRsp = client.loadModel(databaseName, engineName, "test_model", testModel);
assertEquals(false, loadRsp.aborted);
assertEquals(0, loadRsp.output.length);
assertEquals(0, loadRsp.problems.length);
var resp = client.loadModels(databaseName, engineName, testModel);
assertEquals("COMPLETED", resp.transaction.state);
assertEquals(0, resp.problems.size());

// Clone the database
var databaseCloneName = databaseName + "-clone";
Expand Down Expand Up @@ -177,15 +175,11 @@ void testDatabase() throws HttpError, InterruptedException, IOException {
assertNotNull(rel);

// Make sure the model was cloned
var modelNames = client.listModelNames(databaseName, engineName);
var name = find(modelNames, item -> item.equals("test_model"));
var modelNames = client.listModels(databaseName, engineName);
var name = modelNames.stream().filter(n -> n.equals("test_model")).findFirst().orElse(null);
assertNotNull(name);

var models = client.listModels(databaseName, engineName);
var model = find(models, m -> m.name.equals("test_model"));
assertNotNull(model);

model = client.getModel(databaseName, engineName, "test_model");
var model = client.getModel(databaseName, engineName, "test_model");
assertNotNull(model);
assertTrue(model.value.length() > 0);

Expand Down
Loading