[WIP] Add easy model fitting and comparison #77
base: main
Conversation
```diff
@@ -532,3 +536,100 @@ def _convert_binary_to_multiclass(self, predictions: np.ndarray, dataset: str) -
             return np.stack([1 - predictions, predictions], axis=predictions.ndim)
         else:
             return predictions

     def _convert_time_infer_s_from_sample_to_batch(self, df: pd.DataFrame):
```
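The conversion in the hunk above can be illustrated standalone. This is a sketch with a hypothetical free-function name mirroring the diff; the actual method also takes a `dataset` argument that is elided here:

```python
import numpy as np

def convert_binary_to_multiclass(predictions: np.ndarray) -> np.ndarray:
    """Expand a 1-D array of positive-class probabilities into a
    two-column (negative, positive) probability array, as in the diff."""
    if predictions.ndim == 1:
        # For ndim == 1, axis=predictions.ndim stacks along a new last axis.
        return np.stack([1 - predictions, predictions], axis=predictions.ndim)
    return predictions

probs = np.array([0.2, 0.9, 0.5])
expanded = convert_binary_to_multiclass(probs)
# expanded has shape (3, 2); each row sums to 1
```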
Can we move those functions to a util? They are very specific to one analysis; if we put all our utils in AbstractRepository, it will be very hard to maintain.
Yeah, a lot of this code is currently WIP. Basically, we are prioritizing making it work first, and once it works we will clean it up.
```python
"""
Class to fetch train/test splits of the context dataset
"""
def get_context_train_test_split(self, repo: EvaluationRepository, task_id: Union[int, List[int]], repeat: int = 0,
```
Nice to get those!
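For context, what a repeat/fold-indexed split fetcher returns can be sketched generically with scikit-learn. This is not the TabRepo `EvaluationRepository` API, just an illustration of the contract; all names and defaults below are assumptions:

```python
import numpy as np
from sklearn.model_selection import RepeatedKFold

def get_train_test_split(X, y, repeat=0, fold=0, n_splits=3, n_repeats=2, seed=0):
    """Return (X_train, X_test, y_train, y_test) for a given repeat/fold.

    RepeatedKFold yields n_splits * n_repeats splits in order, so the
    split for (repeat, fold) sits at index repeat * n_splits + fold.
    """
    splitter = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=seed)
    splits = list(splitter.split(X))
    train_idx, test_idx = splits[repeat * n_splits + fold]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

X = np.arange(12).reshape(6, 2)
y = np.arange(6)
X_tr, X_te, y_tr, y_te = get_train_test_split(X, y, repeat=1, fold=2)
# With 6 samples and n_splits=3: 4 train rows, 2 test rows
```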
```python
from sklearn.utils.multiclass import unique_labels


class TabForestPFN_sklearn(BaseEstimator, ClassifierMixin):
```
Have not touched this class as it still seems to be a WIP
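Since the class is still a WIP, here is the minimal shape such a scikit-learn wrapper needs to satisfy the estimator contract. The body below is a hypothetical majority-class placeholder, not the actual TabForestPFN model:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import unique_labels

class MajorityClassClassifier(BaseEstimator, ClassifierMixin):
    """Minimal sklearn-compatible classifier skeleton: fit() records the
    observed classes and the majority label; predict() returns it everywhere."""

    def fit(self, X, y):
        self.classes_ = unique_labels(y)  # sklearn convention: sorted class labels
        values, counts = np.unique(y, return_counts=True)
        self.majority_ = values[np.argmax(counts)]
        return self  # fit must return self for pipeline compatibility

    def predict(self, X):
        return np.full(len(X), self.majority_)

clf = MajorityClassClassifier().fit([[0], [1], [2]], [1, 1, 0])
preds = clf.predict([[5], [6]])  # array([1, 1])
```

The real wrapper would replace `fit`/`predict` with calls into the TabForestPFN model, but the `classes_` attribute and `return self` conventions carry over unchanged.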
Description of changes:
This PR adds plotting functionality for comparing the TabRepo configs against the AG fitted models.
Changes compared to the earlier PR (#76):
- `convert_leaderboard_to_configs()` is modified and cleaned up; the earlier version renamed columns that were not present in the DataFrame.
- Added `plot_overall_rank_comparison()`, which plots various figures for all the models in the DataFrame (fitted models + TabRepo configs).
- NOTE: the earlier #76 intentionally contained a fold mismatch to test `compare_metrics()`, but in this PR `temp_script.py` uses the same folds for both the fitted models and the TabRepo configs, so that the folds match while plotting.
- The plot function still breaks when the ELO figures are plotted and the code errors out (still a WIP), but the rest of the plots can be found in `initial_experiment/output/figures`; the code runs up to that point.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
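The rank comparison underlying `plot_overall_rank_comparison()` can be sketched on a toy leaderboard. The column names and values below are assumptions for illustration, not the actual TabRepo schema:

```python
import pandas as pd

# Toy leaderboard: one error value per (dataset, method) pair.
df = pd.DataFrame({
    "dataset": ["a", "a", "b", "b"],
    "method":  ["AG_model", "TabRepo_config", "AG_model", "TabRepo_config"],
    "error":   [0.10, 0.20, 0.30, 0.35],
})
# Rank methods within each dataset (lower error -> rank 1), then average
# across datasets to get an overall rank per method.
df["rank"] = df.groupby("dataset")["error"].rank()
avg_rank = df.groupby("method")["rank"].mean()
# AG_model wins on both datasets -> average rank 1.0
```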