Easy-to-use PyTorch library for cross-domain learning, few-shot learning and meta-learning.
TorchCross is a PyTorch library for cross-domain learning, few-shot learning and meta-learning. It provides convenient utilities for creating cross-domain learning or few-shot learning experiments.
- `torchcross`: The main package, containing the library.
- `torchcross.cd`: Contains functions to create network heads, losses, and metrics for cross-domain learning experiments.
- `torchcross.data`: Contains classes to load data for cross-domain learning or few-shot learning experiments.
- `torchcross.data.task`: Contains the `Task` and `TaskDescription` classes, which represent a task in a few-shot learning scenario and a task's metadata, respectively (see the sketch after this list).
- `torchcross.data.task_source`: Contains the `TaskSource` class, which extends `torch.utils.data.Dataset` and represents a data source for sampling tasks in a few-shot learning scenario. Additionally, it contains utility classes that facilitate working with `TaskSource` objects.
- `torchcross.data.metadataset`: Contains the `MetaDataset` class, which represents a collection of tasks in a few-shot learning scenario. The `FewShotMetaDataset` and `SubTaskRandomFewShotMetaDataset` classes extend `MetaDataset` and provide a convenient interface for sampling few-shot instances from a `TaskSource`.
- `torchcross.models`: Contains `torch.nn.Module` classes for cross-domain and few-shot scenarios.
- `torchcross.models.lightning`: Contains `pytorch_lightning.LightningModule` classes which wrap the `torch.nn.Module` classes in `torchcross.models` and provide convenient training and evaluation routines.
- `torchcross.utils`: Contains various utility functions.
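As a minimal sketch of these abstractions, a task's metadata can be described with a `TaskDescription`; this mirrors the construction used in the MAML example further below:

```python
from torchcross.data import TaskDescription, TaskTarget

# Describe a four-class multi-class classification task by its target type
# and its class dictionary.
task_description = TaskDescription(
    TaskTarget.MULTICLASS_CLASSIFICATION,
    classes={0: "Class A", 1: "Class B", 2: "Class C", 3: "Class D"},
)
```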
This library is still in alpha, and the API is subject to change. Any feedback is welcome.
The library can be installed via pip:
```bash
pip install torchcross
```
The `WrapTaskSource` class can be used to wrap a PyTorch dataset as a `TaskSource`.
```python
import torch.utils.data

from torchcross.data import TaskTarget, WrapTaskSource

# Create a PyTorch dataset
dataset: torch.utils.data.Dataset = ...

# Define the appropriate TaskTarget. Let's assume that the task is a multi-class
# classification task.
task_target = TaskTarget.MULTICLASS_CLASSIFICATION

# Classes can be provided as a dictionary. Let's assume that the dataset contains
# four classes with names "Class A", "Class B", "Class C" and "Class D".
classes = {0: "Class A", 1: "Class B", 2: "Class C", 3: "Class D"}

# Wrap the dataset as a TaskSource
task_source = WrapTaskSource(dataset, task_target, classes)
```
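As a self-contained illustration of the same pattern, assuming `WrapTaskSource` accepts any map-style `(input, label)` dataset, a toy `TensorDataset` can stand in for real data:

```python
import torch
from torch.utils.data import TensorDataset

from torchcross.data import TaskTarget, WrapTaskSource

# A toy dataset standing in for real data: 120 random "images" with labels
# drawn from four classes.
inputs = torch.randn(120, 3, 32, 32)
labels = torch.randint(0, 4, (120,))
dataset = TensorDataset(inputs, labels)

task_source = WrapTaskSource(
    dataset,
    TaskTarget.MULTICLASS_CLASSIFICATION,
    {0: "Class A", 1: "Class B", 2: "Class C", 3: "Class D"},
)
```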
The `FewShotMetaDataset` class can be used to create a few-shot meta-dataset from a `TaskSource`.
```python
import torchcross as tx
from torchcross.data.metadataset import FewShotMetaDataset

# Create a TaskSource.
task_source = ...

# Create a few-shot meta-dataset from the task source. Let's assume that we want to
# sample 100 tasks, each containing 5 support samples and 10 query samples per class.
# The stack function from the tx.utils.collate_fn module stacks the support and
# query samples into a single tensor each.
meta_dataset = FewShotMetaDataset(
    task_source,
    collate_fn=tx.utils.collate_fn.stack,
    n_support_samples_per_class=5,
    n_query_samples_per_class=10,
    filter_classes_min_samples=30,
    length=100,
)
```
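The meta-dataset can then be iterated to draw few-shot tasks. A sketch, assuming (as in the MAML example below) that the `identity` collate function simply returns the sampled tasks unchanged:

```python
from torch.utils.data import DataLoader

from torchcross.utils.collate_fn import identity

# Wrap the meta-dataset in a standard DataLoader. With batch_size=4 and the
# identity collate function, each batch is the list of four sampled few-shot
# tasks, left uncollated.
task_loader = DataLoader(meta_dataset, batch_size=4, collate_fn=identity)

for task_batch in task_loader:
    ...  # consume the sampled tasks, e.g. in a meta-learning training loop
```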
The `SubTaskRandomFewShotMetaDataset` class can be used to create a few-shot meta-dataset from a `TaskSource` where each task is a random sub-task of the original task, meaning that each few-shot task contains a random subset of the original task's classes.
```python
import torchcross as tx
from torchcross.data.metadataset import SubTaskRandomFewShotMetaDataset

# Create a TaskSource.
task_source = ...

# Create a few-shot meta-dataset from the task source. Let's assume that we want to
# sample 100 tasks, each containing a random subset of the original task's classes,
# with 1 to 10 support samples and 10 query samples per class.
# The stack function from the tx.utils.collate_fn module stacks the support and
# query samples into a single tensor each.
few_shot = SubTaskRandomFewShotMetaDataset(
    task_source,
    collate_fn=tx.utils.collate_fn.stack,
    n_support_samples_per_class_min=1,
    n_support_samples_per_class_max=10,
    n_query_samples_per_class=10,
    filter_classes_min_samples=30,
    length=100,
)
```
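Sampling random sub-tasks in this way varies which classes (and how many support samples) each episode contains, which increases task diversity during meta-training compared to always presenting the full class set.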
The `lightning.LightningModule` subclasses in `torchcross.models.lightning` can be used to train a model on a meta-dataset. Let's assume that we want to train a model on a few-shot meta-dataset using the MAML algorithm. We can use the `MAML` class from the `torchcross.models.lightning` module to do so.
```python
from functools import partial

import lightning.pytorch as pl
import torch
import torchopt
from torch.utils.data import DataLoader

from torchcross.data import TaskDescription, TaskTarget
from torchcross.models.lightning import MAML
from torchcross.utils.collate_fn import identity

# Create few-shot meta-datasets for training and validation.
train_meta_dataset = ...
val_meta_dataset = ...

# Create dataloaders for training and validation.
train_dataloader = DataLoader(train_meta_dataset, batch_size=4, collate_fn=identity)
val_dataloader = DataLoader(val_meta_dataset, batch_size=4, collate_fn=identity)

# Create a `TaskDescription` object which describes the type of task that we want
# to solve. Let's assume that the task is a multi-class classification task with
# four classes.
task_description = TaskDescription(
    TaskTarget.MULTICLASS_CLASSIFICATION,
    classes={0: "Class A", 1: "Class B", 2: "Class C", 3: "Class D"},
)

# Create a backbone network. We could, for example, use a pre-trained resnet18
# backbone.
backbone = ...
num_backbone_output_features = 512

# Create inner and outer optimizers and hyperparameters for MAML.
outer_optimizer = partial(torch.optim.Adam, lr=0.001)
inner_optimizer = partial(torchopt.MetaSGD, lr=0.1)
eval_inner_optimizer = partial(torch.optim.SGD, lr=0.1)
num_inner_steps = 4
eval_num_inner_steps = 32

# Create the lightning model with the pre-trained resnet18 backbone.
model = MAML(
    (backbone, num_backbone_output_features),
    task_description,
    outer_optimizer,
    inner_optimizer,
    eval_inner_optimizer,
    num_inner_steps,
    eval_num_inner_steps,
)

# Create the lightning trainer.
trainer = pl.Trainer(
    inference_mode=False,
    max_epochs=10,
    check_val_every_n_epoch=1,
    val_check_interval=1000,
    limit_val_batches=100,
)

# Train the model.
trainer.fit(model, train_dataloader, val_dataloader)
```
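After training, the model can be evaluated with Lightning's standard entry points, for example by running the validation loop once more on the trained weights:

```python
# Run Lightning's validation loop on the trained model.
trainer.validate(model, val_dataloader)
```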
See the MIMeta library for a real-world example of how to use TorchCross to perform cross-domain learning, few-shot learning or meta-learning experiments.