Does this library support Gaussian Process regression from PyTorch or TensorFlow? #127
Hello, thanks for the question! At the moment, RTNeural does not support Gaussian Processes from either framework. It should be possible to implement a Gaussian Process in RTNeural that takes an explicitly defined covariance function, with parameters determined during the "training" process. There are a few problems I can see with how that implementation would work:
Anyway, these are just my initial thoughts. If it seems like there are some aspects of the Gaussian Process that I'm not quite grasping, or if there's more relevant information you can provide, that would be great! Thanks,
I'm using GPR with the RBF kernel (squared exponential), as well as a constant kernel and a white kernel. When I export to ONNX I can see the types of calculations performed, as well as the weights etc. inside. I'm not aware of any random number generators amongst the operators, at least not in the inference part. Here's the list of operators used in the scikit-learn GPR:

Operators used in the model: {'Mul', 'MatMul', 'OneHotEncoder', 'CDist', 'Div', 'Exp', 'Add', 'Shape', 'ConstantOfShape', 'Cast', 'Concat', 'ArrayFeatureExtractor', 'Scaler', 'Reshape', 'ReduceSum', 'Transpose'}

I'm doing one-hot encoding and standard scaling, so that accounts for the ArrayFeatureExtractor, OneHotEncoder, and Scaler. For the GPR part it looks like it's using Div, CDist, Mul, Exp, MatMul, and Add.

I think it would be a nice addition to the library, as both simple and more advanced regression problems could be covered, rather than having to link ONNX Runtime statically, which is quite a large dependency: ~70MB even with the excess operators removed.

In terms of PyTorch and TensorFlow, I'm not sure what their native models look like. I did try TensorFlow, but the export to ONNX was incomplete, so the GPR model could not be output to compare. It depends whether you want to support regression and Gaussian Process regression here, rather than more neural-net-oriented functions. Linear regression is quite simple in Eigen, but the scaling pipeline and one-hot encoding would be handy to have in a library, to save having to program that part yourself.

I've been trying to consolidate on just ONNX as it's cross-platform, but the performance is a lot slower than more specialised alternatives such as this. scikit-learn has a function to convert a pipeline to ONNX, but I'm sure extracting the weights would be possible by studying the implementation, or alternatively extracting them from the ONNX file. ONNX also has models for neural nets etc. I'm not sure if you support ONNX in any way?
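To illustrate, here is a minimal sketch of the preprocessing stage described above (the Scaler, OneHotEncoder, and Concat operators), using only plain array ops and no ONNX Runtime. The column layout, category list, and function name `preprocess` are illustrative assumptions, not the actual exported model.

```python
import numpy as np

def preprocess(x_numeric, category, means, stds, categories):
    """Standard-scale numeric features and one-hot encode a categorical
    feature, mirroring the Scaler/OneHotEncoder/Concat ONNX operators.
    All parameter values here are hypothetical."""
    scaled = (x_numeric - means) / stds                      # Scaler
    onehot = (np.array(categories) == category)              # OneHotEncoder
    return np.concatenate([scaled, onehot.astype(float)])    # Concat
```

The scaling means/stds and the category list would be extracted from the trained scikit-learn pipeline (or from the ONNX graph's initializers).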
I think the random number generator you mentioned was probably in reference to generating a set of samples from the predictive distribution; that was not something I was thinking of, just pure inference using the trained model parameters. So the implementation would be a set of coefficients from a trained kernel, along with CDist and matrix multiplications etc.
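A sketch of what that pure-inference path could look like, using only the primitive ops listed earlier (CDist → Div → Exp → MatMul → Add). The names `X_train`, `alpha`, and `rbf_predict` are placeholders: `alpha` would be the precomputed vector K(X_train, X_train)⁻¹ y from training, so inference needs no matrix inversion and no randomness.

```python
import numpy as np

def rbf_predict(x_query, X_train, alpha, length_scale=1.0, const=1.0, y_mean=0.0):
    """Predictive mean of a GPR with an RBF (squared-exponential) kernel,
    scaled by a constant kernel. Assumes alpha was precomputed at
    training time; the white-noise kernel only affects training."""
    # CDist: squared Euclidean distances between query and training points
    diff = x_query[:, None, :] - X_train[None, :, :]
    sq_dist = np.sum(diff * diff, axis=-1)
    # Div + Exp: evaluate the RBF kernel on the distance matrix
    K_star = const * np.exp(-0.5 * sq_dist / (length_scale ** 2))
    # MatMul + Add: weighted sum of training coefficients, plus mean offset
    return K_star @ alpha + y_mean
```

With the kernel hyperparameters and `alpha` extracted from the trained model, this is just a handful of Eigen-style matrix operations, which is what makes it look feasible inside RTNeural.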
I'm currently using models from scikit-learn that have been exported to ONNX, but I was curious: if I had a model for GPR in either PyTorch or TensorFlow, would it be usable from RTNeural?
I think the TensorFlow GPR is in the probability module: https://www.tensorflow.org/probability/examples/Gaussian_Process_Regression_In_TFP and also in GPflow: https://www.gpflow.org
And in PyTorch it's available in GPyTorch: https://docs.gpytorch.ai/en/stable/ and Pyro: https://pyro.ai/examples/gp.html
I was curious if any of these would work, as I was planning on using RTNeural in the future and could drop the usage of ONNX Runtime if RTNeural could load GPR models.