I am using DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:latest" to save some RAM, but generation fails with the error below.
I think it is caused by the JAX version. Can anyone share a JAX version that works?
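Version reports are easier to compare when everyone posts the exact installed versions. A small script like the following could print them. This is a minimal sketch that uses only the standard library's importlib.metadata (available from Python 3.8). The package list (jax, jaxlib, flax) is an assumption about what matters for dalle-mini and is not taken from the original report.

```python
import importlib.metadata as md  # stdlib since Python 3.8
import platform

# Assumed package list; adjust to whatever the notebook actually installs.
PACKAGES = ("jax", "jaxlib", "flax")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python:", platform.python_version())
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```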
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']
0% | 0/2 [00:03<?, ?it/s]
/usr/local/lib/python3.8/dist-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
2023-02-01 07:02:31.314672: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:219] failed to create cublas handle: the library was not initialized
2023-02-01 07:02:31.314705: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:222] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2023-02-01 07:02:31.315671: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:57] INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:371) stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)
*** Begin stack trace ***
PyCFunction_Call
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyObject_MakeTpCall
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyEval_EvalCodeWithName
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
*** End stack trace ***
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[16], line 13
11 key, subkey = jax.random.split(key)
12 # generate images
---> 13 encoded_images = p_generate(
14 tokenized_prompt,
15 shard_prng_key(subkey),
16 params,
17 gen_top_k,
18 gen_top_p,
19 temperature,
20 cond_scale,
21 )
22 # remove BOS
23 encoded_images = encoded_images.sequences[..., 1:]
[... skipping hidden 11 frame]
File /usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py:1024, in backend_compile(backend, built_c, options, host_callbacks)
1019 return backend.compile(built_c, compile_options=options,
1020 host_callbacks=host_callbacks)
1021 # Some backends don't have `host_callbacks` option yet
1022 # TODO(sharadmv): remove this fallback when all backends allow `compile`
1023 # to take in `host_callbacks`
-> 1024 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:371) stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms)
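The cuBLAS message in the log names out-of-memory as one possible cause, because JAX preallocates a large fraction of GPU memory by default and cuBLAS needs some free memory to create its handle. If the failure turns out to be memory-related rather than a version mismatch, one thing to try is the XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION environment variables from the JAX GPU memory documentation. This is a minimal sketch that assumes the variables are set before jax is first imported. In a notebook, that means restarting the kernel. The helper name configure_xla_gpu_memory is hypothetical.

```python
import os


def configure_xla_gpu_memory(preallocate=False, mem_fraction=None):
    """Hypothetical helper: set JAX's GPU memory environment variables.

    Only takes effect if called before `import jax` in a fresh process.
    """
    # "false" disables XLA's up-front reservation of most GPU memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory to preallocate when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    # Return the XLA client settings now in effect, for logging.
    return {k: v for k, v in os.environ.items() if k.startswith("XLA_PYTHON_CLIENT_")}


if __name__ == "__main__":
    print(configure_xla_gpu_memory(preallocate=False))
    # import jax  # only after the variables above are set
```

With preallocation disabled, JAX allocates GPU memory on demand, which may leave room for cuBLAS to initialize at the cost of possible fragmentation. Whether this resolves this particular error is untested.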