This repository is an implementation of exploiting the generalized mean for per-task loss aggregation in multi-task learning (GeMTL). Our code is mainly based on LibMTL.
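For intuition, the generalized (power) mean with exponent `p` aggregates the per-task losses `L_1, ..., L_T` as `((1/T) * sum_i L_i^p)^(1/p)`: `p = 1` recovers the plain average, and larger `p` puts more weight on the worst-performing tasks. The sketch below illustrates this aggregation in PyTorch; the function name `generalized_mean_loss` and the exponent values are illustrative only, not the exact interface used in this repository.

```python
import torch

def generalized_mean_loss(task_losses, p: float = 2.0, eps: float = 1e-8):
    """Aggregate per-task losses with the generalized (power) mean.

    p = 1 recovers the plain average of the task losses; larger p puts
    more weight on the worst-performing (largest-loss) tasks. `eps`
    keeps the power stable when a loss is exactly zero.
    """
    losses = torch.stack(list(task_losses))  # shape: (num_tasks,)
    return losses.clamp_min(eps).pow(p).mean().pow(1.0 / p)

# Example: aggregate two task losses and backpropagate through the result.
l1 = torch.tensor(0.8, requires_grad=True)
l2 = torch.tensor(0.2, requires_grad=True)
total = generalized_mean_loss([l1, l2], p=3.0)
total.backward()
print(total.item(), l1.grad.item(), l2.grad.item())
```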
- Create a virtual environment

  ```bash
  conda create -n gemtl python=3.8
  conda activate gemtl
  pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
  ```
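  As a quick sanity check (not part of the repository), you can confirm that the CUDA build of PyTorch was installed correctly:

  ```python
  import torch

  # Expected: a 1.8.1+cu111 build, with CUDA available on a GPU machine.
  print(torch.__version__)
  print(torch.cuda.is_available())
  ```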
- Clone this repository
- Install LibMTL

  ```bash
  cd GeMTL
  pip install -e .
  ```
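  If the editable install succeeded, the package should now be importable; this is a minimal check, assuming the package imports as `LibMTL` as in the upstream project:

  ```python
  # Verify that the editable LibMTL install is visible to the interpreter.
  import LibMTL
  print(LibMTL.__file__)
  ```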
- Install the remaining dependencies (Python >= 3.8 and PyTorch >= 1.8.1 are required)

  ```bash
  pip install -r requirements.txt
  ```
You can download the datasets from the following links.
Training and testing code is in `./examples/{nyusp, office}/main.py`.
You can check the results by running the following commands.

```bash
cd ./examples/{nyusp, office}
bash run.sh
```
Our implementation is developed on top of the following repositories. Thanks to the contributors!
This repository is released under the MIT license.