~/benchmark > git clone https://github.com/openxla/iree-comparative-benchmark
Cloning into 'iree-comparative-benchmark'...
remote: Enumerating objects: 2003, done.
remote: Counting objects: 100% (1051/1051), done.
remote: Compressing objects: 100% (440/440), done.
remote: Total 2003 (delta 770), reused 683 (delta 585), pack-reused 952
Receiving objects: 100% (2003/2003), 1.01 MiB | 3.58 MiB/s, done.
Resolving deltas: 100% (1148/1148), done.
~/benchmark > cd iree-comparative-benchmark/comparative_benchmark/jax
~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > ./setup_venv.sh
...
~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > source jax-benchmarks.venv/bin/activate
(jax-benchmarks.venv) ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > ./run_benchmarks.py -o test -device host-cpu -name models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT
--- models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT ---
/usr/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(jax-benchmarks.venv) ~/benchmark/iree-comparative-benchmark/comparative_benchmark/jax > cat test
{"benchmarks": [{"definition": {"benchmark_name": "models/BERT_BASE_FP32_JAX_I32_SEQLEN32/inputs/INPUT_DATA_MODEL_DEFAULT", "framework": "ModelFrameworkType.JAX", "data_type": "fp32", "batch_size": 1, "compiler": "xla", "device": "host-cpu", "tags": ["transformer-encoder", "bert", "seqlen-32"]}, "metrics": {"framework_level": {"error": "['Output 0 exceeds tolerance. Max diff: 8.474491119384766, atol: 0.5, rtol: 0']"}}}]}
I tried both -device host-cpu and -device host-gpu, across more than 20 different models. Only models/SD_PIPELINE_FP16_JAX_64XI32_BATCH1/inputs/INPUT_DATA_MODEL_DEFAULT produced what look like valid timings; every other model fails with "Output 0 exceeds tolerance."
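Two lines in the log may be relevant. First, the FlaxBertModel warning means the bert-base-uncased checkpoint does not contain pooler weights for this Flax class, so transformers initializes them randomly on every load; if the stored reference outputs depend on the pooler, a mismatch like the one above would be expected. The warning is reproducible with transformers alone:

```python
from transformers import FlaxBertModel

# bert-base-uncased ships without pooler weights for the Flax class, so
# from_pretrained initializes ('pooler', 'dense', ...) randomly each load,
# producing the "newly initialized" warning seen in the benchmark run.
model = FlaxBertModel.from_pretrained("bert-base-uncased")
```

Second, the os.fork() RuntimeWarning comes from Python's multiprocessing defaulting to fork on Linux while JAX is multithreaded. A common mitigation, if the harness allows it, is to select the spawn start method before any worker is created (whether run_benchmarks.py exposes this is an assumption on my part):

```python
import multiprocessing

# 'spawn' starts a fresh interpreter, so the child process does not
# inherit JAX's threads the way a fork()ed child would.
multiprocessing.set_start_method("spawn")
```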