While running this in Google Colab, I get the following error. I am using the Pro version of Google Colab.
XlaRuntimeError Traceback (most recent call last)
in <cell line: 9>()
 9 for i in trange(max(n_predictions // jax.device_count(), 1)):
 10 # get a new key
---> 11 key, subkey = jax.random.split(key)
 12 # generate images
 13 encoded_images = p_generate(

10 frames
[... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
 893 runtime_token = None
 894 else:
--> 895 out_flat = compiled.execute(in_flat)
 896 check_special(name, out_flat)
 897 out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: cudaGetErrorString symbol not found
This happens when running this cell:

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")

images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()
Hello! Did you solve this problem?