Output 0 exceeds tolerance? #174

Open
ekuznetsov139 opened this issue Apr 18, 2024 · 0 comments
ekuznetsov139 commented Apr 18, 2024

```
 ~/benchmark > git clone https://github.com/openxla/iree-comparative-benchmark
Cloning into 'iree-comparative-benchmark'...
remote: Enumerating objects: 2003, done.
remote: Counting objects: 100% (1051/1051), done.
remote: Compressing objects: 100% (440/440), done.
remote: Total 2003 (delta 770), reused 683 (delta 585), pack-reused 952
Receiving objects: 100% (2003/2003), 1.01 MiB | 3.58 MiB/s, done.
Resolving deltas: 100% (1148/1148), done.
 ~/benchmark > cd iree-comparative-benchmark/comparative_benchmark/jax
 ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > ./setup_venv.sh
...
 ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > source jax-benchmarks.venv/bin/activate
(jax-benchmarks.venv)  ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > ./run_benchmarks.py -o test -device host-cpu -name models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT

--- models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT ---
/usr/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(jax-benchmarks.venv) ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > cat test
{"benchmarks": [{"definition": {"benchmark_name": "models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT", "framework": "ModelFrameworkType.JAX", "data_type": "fp32", "batch_size": 1, "compiler": "xla", "device": "host-cpu", "tags": ["transformer-encoder", "bert", "seqlen-32"]}, "metrics": {"framework_level": {"error": "['Output 0 exceeds tolerance. Max diff: 8.474491119384766, atol: 0.5, rtol: 0']"}}}]}
```
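For reference, the error string above is the shape produced by an absolute/relative tolerance check on the model outputs. The sketch below is a hypothetical reconstruction of such a check (the function name `check_tolerance` is mine, not from the benchmark code), just to illustrate what "Max diff: 8.47, atol: 0.5, rtol: 0" means:

```python
import numpy as np

def check_tolerance(expected, actual, atol=0.5, rtol=0.0):
    # Hypothetical verification step: an output fails when any element of
    # |actual - expected| exceeds atol + rtol * |expected|.
    expected = np.asarray(expected, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    diff = np.abs(actual - expected)
    max_diff = float(diff.max())
    if not np.all(diff <= atol + rtol * np.abs(expected)):
        raise ValueError(
            f"Output 0 exceeds tolerance. Max diff: {max_diff}, "
            f"atol: {atol}, rtol: {rtol}")
    return max_diff
```

With rtol fixed at 0, any element deviating from the golden output by more than atol = 0.5 fails the whole benchmark, which is consistent with the max diff of ~8.47 reported here.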

I tried both -device host-cpu and -device host-gpu, across over 20 different models. Only models/SD_PIPELINE_FP16_JAX_64XI32_BATCH1/inputs/INPUT_DATA_MODEL_DEFAULT produced what look like valid timings; all the others return "Output 0 exceeds tolerance."
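To survey failures across many models quickly, the per-benchmark errors can be pulled out of the JSON results file (same format as the `cat test` dump above). This is a small helper I wrote for illustration, not part of the repository:

```python
import json

def collect_errors(path):
    """Map each benchmark_name in a results file to its framework_level
    error string, or None if the benchmark succeeded."""
    with open(path) as f:
        results = json.load(f)
    return {
        b["definition"]["benchmark_name"]: b["metrics"]["framework_level"].get("error")
        for b in results["benchmarks"]
    }
```

Running this over the output files for all ~20 models makes it easy to confirm that every failure is the same tolerance error rather than a mix of issues.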
