Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grid search #154

Merged
merged 13 commits into from
Aug 8, 2023
Merged

Grid search #154

merged 13 commits into from
Aug 8, 2023

Conversation

msluszniak
Copy link
Contributor

@msluszniak msluszniak commented Aug 7, 2023

Closes #148

I also added some minor improvements in linear and ridge regression

@@ -105,9 +106,129 @@ defmodule Scholar.ModelSelection do
]
>
"""
def cross_validate(x, y, folding_fun, scoring_fun) do
def cross_validate(x, y, folding_fun, scoring_fun, opts \\ []) do
Copy link
Contributor

@josevalim josevalim Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to support version without options?

scoring_fun =
cond do
  is_function(scoring_fun, 3) ->
    fn x, y -> scoring_fun.(x, y, opts) end

  is_function(scoring_fun, 2) and opts != [] -> 
    raise ArgumentError, "no options must be given if scoring_fun expects two arguments"

  is_function(scoring_fun, 2) -> 
    scoring_fun

  true ->
    raise ArgumentError, "expected scoring_fun to be a function of arity 2 or 3, got: #{inspect(scoring_fun)}"
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait a second... we don't process the opts in anyway. So I don't see why we need to pass them here? If you need to pass options to the scoring_fun, you can just close over it:

     iex> scoring_fun = fn x, y ->
      ...>   {x_train, x_test} = x
      ...>   {y_train, y_test} = y
      ...>   model = Scholar.Linear.LinearRegression.fit(x_train, y_train, some_option: 123)

I think we can remove the opts argument in the cross validation functions now that we have weighted cross validate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also posted an example on grid search. :)

Copy link
Contributor

@josevalim josevalim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added some comments! Let's also add guards to those functions :)

sample_weights = Nx.tensor(sample_weights, type: x_type)

sample_weights =
if Nx.is_tensor(sample_weights),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update the docs and add simple tests for these changes. :)

lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
lib/scholar/model_selection.ex Outdated Show resolved Hide resolved
...> ]
iex> Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun, opts)
"""
def grid_search(x, y, folding_fun, scoring_fun, opts) when is_list(opts) do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add guards for scoring funs here and below. And please don't forget the tests for the weights as tensors :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I remember about them ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, there were tests for weights as a tensor in linear regression. I added a test case where types of data and weights differ.

@msluszniak msluszniak merged commit 736e640 into elixir-nx:main Aug 8, 2023
@msluszniak msluszniak deleted the grid_search branch August 8, 2023 19:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Grid Search with Cross Validation
3 participants