-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
60 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,39 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# This demo shows the most basic usage of the napkinXC library. | ||
|
||
from napkinxc.datasets import load_dataset | ||
from napkinxc.models import PLT | ||
from napkinxc.measures import precision_at_k | ||
|
||
# Use the load_dataset function to load one of the benchmark datasets | ||
# from XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html) | ||
# from XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html). | ||
X_train, Y_train = load_dataset("eurlex-4k", "train") | ||
X_test, Y_test = load_dataset("eurlex-4k", "test") | ||
|
||
# Create Probabilistic Labels Tree models, | ||
# directory "eurlex-model" will be created and used for model training and storing | ||
# Create a Probabilistic Labels Tree model; | ||
# the directory "eurlex-model" will be created and used during model training. | ||
# napkinXC stores already-trained parts of the model to save RAM. | ||
# The model directory is the only required argument for model constructors. | ||
plt = PLT("eurlex-model") | ||
|
||
# Fit the model on the train dataset | ||
# Fit the model on the training dataset. | ||
# The model weights and additional data will be stored in the "eurlex-model" directory. | ||
# The feature matrix X must be a SciPy csr_matrix, a NumPy array, or a list of tuples of (idx, value), | ||
# while the label matrix Y should be a list of lists or tuples containing positive labels. | ||
plt.fit(X_train, Y_train) | ||
|
||
# Predict only the best label for each datapoint in the test dataset | ||
# After training, the model is not loaded into RAM. | ||
# You can preload the model into RAM before performing prediction. | ||
plt.load() | ||
|
||
# Predict only the best label (top-1 label) for each data point in the test dataset. | ||
# This will also load the model if it is not loaded. | ||
Y_pred = plt.predict(X_test, top_k=1) | ||
|
||
# Evaluate precision at 1 | ||
# Evaluate the predictions with the precision-at-1 measure. | ||
print(precision_at_k(Y_test, Y_pred, k=1)) | ||
|
||
# Unload the model from RAM. | ||
# You can also just delete the object if you no longer need it. | ||
plt.unload() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# This demo shows how to train, store and later load the napkinXC model. | ||
|
||
from napkinxc.datasets import load_dataset | ||
from napkinxc.models import PLT | ||
from napkinxc.measures import precision_at_k | ||
|
||
# The beginning is the same as in the basic.py example. | ||
|
||
# Use the load_dataset function to load one of the benchmark datasets | ||
# from the XML Repository (http://manikvarma.org/downloads/XC/XMLRepository.html). | ||
X_train, Y_train = load_dataset("eurlex-4k", "train") | ||
X_test, Y_test = load_dataset("eurlex-4k", "test") | ||
|
||
# Create a PLT model with the "eurlex-model" directory; | ||
# the directory will be created and used during model training for storing weights. | ||
# napkinXC stores already-trained parts of the model to save RAM. | ||
plt = PLT("eurlex-model") | ||
|
||
# Fit the model on the training dataset. | ||
# The model weights and additional data will be stored in the "eurlex-model" directory. | ||
plt.fit(X_train, Y_train) | ||
|
||
# Predict. | ||
Y_pred = plt.predict(X_test, top_k=1) | ||
print(precision_at_k(Y_test, Y_pred, k=1)) | ||
|
||
# Delete plt object. | ||
del plt | ||
|
||
# To load the model, create a new PLT object with the same directory as the previous one. | ||
new_plt = PLT("eurlex-model") | ||
|
||
# Predict using a new model object. | ||
Y_pred = new_plt.predict(X_test, top_k=1) | ||
print(precision_at_k(Y_test, Y_pred, k=1)) |