
News

November 2: v2.7.0

July 24: v2.6.0

  • Changed the emb argument of DistributedLossWrapper.forward to embeddings to be consistent with the rest of the library.
  • Added a warning and early-return when DistributedLossWrapper is being used in a non-distributed setting.
  • Thank you elisim.

Documentation

Google Colab Examples

See the examples folder for notebooks you can download or run on Google Colab.

PyTorch Metric Learning Overview

This library contains 9 modules, each of which can be used independently within your existing codebase or combined for a complete train/test workflow.

[Figure: high-level module overview]

How loss functions work

Using losses and miners in your training loop

Let’s initialize a plain TripletMarginLoss:

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()

To compute the loss in your training loop, pass in the embeddings computed by your model, and the corresponding labels. The embeddings should have size (N, embedding_size), and the labels should have size (N), where N is the batch size.

# your training loop
for i, (data, labels) in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = model(data)
	loss = loss_func(embeddings, labels)
	loss.backward()
	optimizer.step()

The TripletMarginLoss computes all possible triplets within the batch, based on the labels you pass into it. Anchor-positive pairs are formed by embeddings that share the same label, and anchor-negative pairs are formed by embeddings that have different labels.
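To make the triplet-forming rule concrete, here is a minimal pure-Python sketch that enumerates every (anchor, positive, negative) index triplet implied by a list of labels. The function name `all_triplets` is made up for this illustration; it is not the library's implementation, which works on tensors.

```python
from itertools import permutations


def all_triplets(labels):
    """Return every (anchor, positive, negative) index triplet implied by the labels."""
    triplets = []
    # permutations(..., 2) yields ordered pairs of distinct indices,
    # so an embedding is never paired with itself.
    for a, p in permutations(range(len(labels)), 2):
        if labels[a] != labels[p]:
            continue  # anchor and positive must share a label
        for n in range(len(labels)):
            if labels[n] != labels[a]:  # negative must have a different label
                triplets.append((a, p, n))
    return triplets


labels = [0, 0, 1, 1]
triplets = all_triplets(labels)
print(len(triplets), triplets[:3])
```

The number of triplets grows quickly with batch size, which is one reason a mining step can be useful.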

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

# your training loop
for i, (data, labels) in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = model(data)
	hard_pairs = miner(embeddings, labels)
	loss = loss_func(embeddings, labels, hard_pairs)
	loss.backward()
	optimizer.step()

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
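As a rough illustration of how pairs can be turned into triplets, the sketch below assumes the miner returns four index lists (anchors and positives for the positive pairs, anchors and negatives for the negative pairs) and joins them wherever the anchor matches. The helper `pairs_to_triplets` and the sample indices are hypothetical; the library's own conversion operates on tensors and may differ in detail.

```python
def pairs_to_triplets(a1, p, a2, n):
    """Join positive pairs (a1[i], p[i]) with negative pairs (a2[j], n[j])
    into triplets wherever both pairs share the same anchor index."""
    triplets = []
    for anchor_pos, positive in zip(a1, p):
        for anchor_neg, negative in zip(a2, n):
            if anchor_pos == anchor_neg:
                triplets.append((anchor_pos, positive, negative))
    return triplets


# Hypothetical miner output for a batch of 4 embeddings
a1, p = [0, 1], [1, 0]        # positive pairs: (0, 1) and (1, 0)
a2, n = [0, 0, 1], [2, 3, 2]  # negative pairs: (0, 2), (0, 3), (1, 2)
print(pairs_to_triplets(a1, p, a2, n))
```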

Customizing loss functions

Loss functions can be customized using distances, reducers, and regularizers. In the diagram below, a miner finds the indices of hard pairs within a batch. These are used to index into the distance matrix, computed by the distance object. For this diagram, the loss function is pair-based, so it computes a loss per pair. In addition, a regularizer has been supplied, so a regularization loss is computed for each embedding in the batch. The per-pair and per-element losses are passed to the reducer, which (in this diagram) only keeps losses with a high value. The averages are computed for the high-valued pair and element losses, and are then added together to obtain the final loss.

[Figure: high-level loss function overview]

Now here's an example of a customized TripletMarginLoss:

from pytorch_metric_learning import losses
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer

loss_func = losses.TripletMarginLoss(
    distance=CosineSimilarity(),
    reducer=ThresholdReducer(high=0.3),
    embedding_regularizer=LpRegularizer(),
)

This customized triplet loss has the following properties:

  • The loss will be computed using cosine similarity instead of Euclidean distance.
  • All triplet losses that are higher than 0.3 will be discarded.
  • The embeddings will be L2 regularized.
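The three properties above can be traced by hand on a tiny batch. The following pure-Python sketch is only an approximation of what happens inside the library: the function names, the margin of 0.05, the regularization weight of 1, and the choice to average the regularization term over all embeddings are assumptions made for this illustration.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def cosine_similarity(u, v):
    return sum(x * y for x, y in zip(u, v)) / (l2_norm(u) * l2_norm(v))


def customized_triplet_loss(embeddings, triplets, margin=0.05, high=0.3):
    # 1. Per-triplet loss using a similarity: the anchor-positive similarity
    #    should exceed the anchor-negative similarity by at least `margin`.
    per_triplet = []
    for a, p, n in triplets:
        sim_ap = cosine_similarity(embeddings[a], embeddings[p])
        sim_an = cosine_similarity(embeddings[a], embeddings[n])
        per_triplet.append(max(0.0, sim_an - sim_ap + margin))

    # 2. Threshold reduction: discard losses above `high`, average the rest.
    kept = [loss for loss in per_triplet if loss < high]
    triplet_term = sum(kept) / len(kept) if kept else 0.0

    # 3. L2 regularization: average norm of the embeddings (weight assumed to be 1).
    reg_term = sum(l2_norm(e) for e in embeddings) / len(embeddings)

    return triplet_term + reg_term, per_triplet


embeddings = [[1.0, 0.0], [0.6, 0.8], [0.8, 0.6], [0.0, 1.0]]
triplets = [(0, 1, 2), (1, 0, 2), (0, 1, 3)]
total, per_triplet = customized_triplet_loss(embeddings, triplets)
print(per_triplet, total)
```

In this toy batch the second triplet has a loss above 0.3, so it is dropped before averaging, while the regularization term is computed over all four embeddings.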

Using loss functions for unsupervised / self-supervised learning

A SelfSupervisedLoss wrapper is provided for self-supervised learning:

from pytorch_metric_learning.losses import SelfSupervisedLoss, TripletMarginLoss
loss_func = SelfSupervisedLoss(TripletMarginLoss())

# your training for-loop
for i, data in enumerate(dataloader):
	optimizer.zero_grad()
	embeddings = your_model(data)
	augmented = your_model(your_augmentation(data))
	loss = loss_func(embeddings, augmented)
	loss.backward()
	optimizer.step()
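One way to picture what the wrapper does with the two sets of embeddings: each sample receives its own pseudo-label, and its augmented view reuses that label, so the only positive for a sample is its own augmentation. The sketch below illustrates that idea in pure Python; `self_supervised_pairs` is a hypothetical helper, and the actual wrapper may construct its labels differently.

```python
def self_supervised_pairs(batch_size):
    """Enumerate positive and negative (original, augmented) index pairs."""
    labels = list(range(batch_size))      # one pseudo-label per original sample
    ref_labels = list(range(batch_size))  # each augmented view reuses that pseudo-label
    positives, negatives = [], []
    for i, label in enumerate(labels):
        for j, ref_label in enumerate(ref_labels):
            if label == ref_label:
                positives.append((i, j))  # sample i and its own augmentation
            else:
                negatives.append((i, j))  # sample i and someone else's augmentation
    return positives, negatives


positives, negatives = self_supervised_pairs(3)
print(positives)
print(len(negatives))
```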

If you're interested in MoCo-style self-supervision, take a look at the MoCo on CIFAR10 notebook. It uses CrossBatchMemory to implement the momentum encoder queue, which means you can use any tuple loss, and any tuple miner to extract hard samples from the queue.
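The queue idea can be sketched with a fixed-size first-in, first-out buffer: new embeddings are appended each iteration and the oldest ones fall off once the buffer is full. The `EmbeddingQueue` class below is a hypothetical stand-in written for this illustration, not the CrossBatchMemory API.

```python
from collections import deque


class EmbeddingQueue:
    """Fixed-size FIFO buffer of (embedding, label) entries."""

    def __init__(self, memory_size):
        # deque(maxlen=...) silently drops the oldest entries when full
        self.embeddings = deque(maxlen=memory_size)
        self.labels = deque(maxlen=memory_size)

    def enqueue(self, batch_embeddings, batch_labels):
        self.embeddings.extend(batch_embeddings)
        self.labels.extend(batch_labels)

    def contents(self):
        return list(self.embeddings), list(self.labels)


queue = EmbeddingQueue(memory_size=4)
queue.enqueue([[0.1], [0.2], [0.3]], [0, 1, 2])
queue.enqueue([[0.4], [0.5]], [3, 4])  # the oldest entry (label 0) is evicted
queued_embeddings, queued_labels = queue.contents()
print(queued_labels)
```

A loss or miner can then compare the current batch against everything in the buffer, which gives it many more negatives than a single batch would.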

Highlights of the rest of the library

  • For a convenient way to train your model, take a look at the trainers.
  • Want to test your model's accuracy on a dataset? Try the testers.
  • To compute the accuracy of an embedding space directly, use AccuracyCalculator.
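As an example of what "accuracy of an embedding space" can mean, the sketch below computes precision at 1: the fraction of embeddings whose nearest neighbor (excluding itself) has the same label. This is a conceptual pure-Python illustration using squared Euclidean distance; the function name is made up, and it does not reproduce the AccuracyCalculator interface.

```python
def precision_at_1(embeddings, labels):
    """Fraction of embeddings whose nearest other embedding shares their label."""
    correct = 0
    for i, query in enumerate(embeddings):
        best_j, best_dist = None, float("inf")
        for j, reference in enumerate(embeddings):
            if i == j:
                continue  # never match an embedding with itself
            dist = sum((x - y) ** 2 for x, y in zip(query, reference))
            if dist < best_dist:
                best_j, best_dist = j, dist
        correct += labels[best_j] == labels[i]
    return correct / len(embeddings)


embeddings = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
well_clustered = precision_at_1(embeddings, [0, 0, 1, 1])
badly_clustered = precision_at_1(embeddings, [0, 1, 0, 1])
print(well_clustered, badly_clustered)
```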

If you're short of time and want a complete train/test workflow, check out the example Google Colab notebooks.

To learn more about all of the above, see the documentation.

Installation

Required PyTorch version

  • pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
  • pytorch-metric-learning < v0.9.90 doesn't have a version requirement, but was tested with torch >= 1.2

Other dependencies: numpy, scikit-learn, tqdm, torchvision

Pip

pip install pytorch-metric-learning

To get the latest dev version:

pip install pytorch-metric-learning --pre

To install on Windows:

pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning

To install with evaluation and logging capabilities

(This will install the unofficial PyPI version of faiss-gpu, plus record-keeper and tensorboard):

pip install pytorch-metric-learning[with-hooks]

To install with evaluation and logging capabilities (CPU)

(This will install the unofficial PyPI version of faiss-cpu, plus record-keeper and tensorboard):

pip install pytorch-metric-learning[with-hooks-cpu]

Conda

conda install -c conda-forge pytorch-metric-learning

To use the testing module, you'll need faiss, which can be installed via conda as well. See the installation instructions for faiss.

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Development

Development is done on the dev branch:

git checkout dev

Unit tests can be run with the default unittest library:

python -m unittest discover

You can specify the test datatypes and test device as environment variables. For example, to test using float32 and float64 on the CPU:

TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discover

To run a single test file instead of the entire test suite, specify the file name:

python -m unittest tests/losses/test_angular_loss.py

Code is formatted using black and isort:

pip install black isort
./format_code.sh

Acknowledgements

Contributors

Thanks to the contributors who made pull requests!

Contributor Highlights
  • domenicoMuscill0: ManifoldLoss; P2SGradLoss; HistogramLoss; DynamicSoftMarginLoss; RankedListLoss
  • mlopezantequera: Made the testers work on any combination of query and reference sets; made AccuracyCalculator work with arbitrary label comparisons
  • cwkeam: SelfSupervisedLoss; VICRegLoss; added mean reciprocal rank accuracy to AccuracyCalculator; BaseLossWrapper
  • marijnl: BatchEasyHardMiner; TwoStreamMetricLoss; GlobalTwoStreamEmbeddingSpaceTester; example using trainers.TwoStreamMetricLoss
  • ir2718: ThresholdConsistentMarginLoss
  • chingisooinar: SubCenterArcFaceLoss
  • elias-ramzi: HierarchicalSampler
  • fjsj: SupConLoss
  • AlenUbuntu: CircleLoss
  • interestingzhuo: PNPLoss
  • wconnell: Learning a scRNAseq Metric Embedding
  • mkmenta: Improved get_all_triplets_indices (fixed the INT_MAX error)
  • AlexSchuy: Optimized utils.loss_and_miner_utils.get_random_triplet_indices
  • JohnGiorgi: all_gather in utils.distributed
  • Hummer12007: utils.key_checker
  • vltanh: Made InferenceModel.train_indexer accept datasets
  • btseytlin: get_nearest_neighbors in InferenceModel
  • mlw214: Added return_per_class to AccuracyCalculator
  • layumi: InstanceLoss
  • NoTody: Helped add ref_emb and ref_labels to the distributed wrappers
  • ElisonSherton: Fixed an edge case in ArcFaceLoss
  • stompsjo: Improved documentation for NTXentLoss
  • Puzer: Bug fix for PNPLoss
  • elisim: Developer improvements to DistributedLossWrapper
  • GaetanLepage
  • z1w
  • thinline72
  • tpanum
  • fralik
  • joaqo
  • JoOkuma
  • gkouros
  • yutanakamura-tky
  • KinglittleQ
  • martin0258
  • michaeldeyzel
  • HSinger04
  • rheum
  • bot66

Facebook AI

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.

Open-source repos

This library contains code that has been adapted and modified from the following great open-source repos:

Logo

Thanks to Jeff Musgrave for designing the logo.

Citing this library

If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@article{Musgrave2020PyTorchML,
  title={PyTorch Metric Learning},
  author={Kevin Musgrave and Serge J. Belongie and Ser-Nam Lim},
  journal={ArXiv},
  year={2020},
  volume={abs/2008.09164}
}