!eWOErHSaiddIbsUNsJ:nixos.org

NixOS CUDA

211 Members
CUDA packages maintenance and support in nixpkgs | https://github.com/orgs/NixOS/projects/27/ | https://nixos.org/manual/nixpkgs/unstable/#cuda
42 Servers



Sender Message Time
11 Jul 2024
@ss:someonex.netSomeoneSerge (utc+3) ROCm stuff cache-missing again 12:39:44
@ss:someonex.netSomeoneSerge (utc+3) https://hydra.nixos.org/job/nixpkgs/trunk/rocmPackages.rocsolver.x86_64-linux yeah getting a timeout locally as well 12:40:02
@yorik.sar:matrix.orgyorik.sar joined the room. 15:05:02
@ss:someonex.netSomeoneSerge (utc+3) Madoura I'm trying to build rocsolver again but I already suspect it's going to time out another time 16:30:42
@ss:someonex.netSomeoneSerge (utc+3) Madoura https://discourse.nixos.org/t/testing-gpu-compute-on-amd-apu-nixos/47060/2 this is falling apart 😹 23:16:45
12 Jul 2024
@valconius:matrix.org@valconius:matrix.org left the room. 01:16:15
13 Jul 2024
@mcwitt:matrix.orgmcwitt

This might just be me being dumb, but I'm surprised that I'm unable to build jax with doCheck = false (use case is I want to override jaxlib = jaxlibWithCuda and don't want to run the tests). Repro:

nix build --impure --expr 'let nixpkgs = builtins.getFlake("github:nixos/nixpkgs/nixpkgs-unstable"); pkgs = import nixpkgs { system = "x86_64-linux"; config.allowUnfree = true; }; in pkgs.python3Packages.jax.overridePythonAttrs { doCheck = false; }'

fails with ModuleNotFoundError: jax requires jaxlib to be installed

01:58:54
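One possible explanation, not confirmed anywhere in this thread: nixpkgs Python packages run an import check (`pythonImportsCheck`) separately from the test suite, so `doCheck = false` may drop check-only inputs (jaxlib among them, if that is where jax lists it) while `import jax` is still attempted. A minimal sketch under that assumption, which also empties `pythonImportsCheck`:

```nix
# Sketch only; assumes the failure comes from the import check, not the tests.
let
  nixpkgs = builtins.getFlake "github:nixos/nixpkgs/nixpkgs-unstable";
  pkgs = import nixpkgs {
    system = "x86_64-linux";
    config.allowUnfree = true;
  };
in
pkgs.python3Packages.jax.overridePythonAttrs (_: {
  doCheck = false;           # skip the test suite, as in the repro above
  pythonImportsCheck = [ ];  # assumption: also skip the `import jax` check
})
```

This expression could be passed to `nix build --impure --expr` in place of the one in the repro.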
@mcwitt:matrix.orgmcwitt

Something else that's puzzling me: overriding with jaxlib = jaxlibWithCuda doesn't seem to work for numpyro (which has a pretty simple derivation):

Sanity check: a Python env with just jax and jaxlibWithCuda is GPU-enabled, as expected:

nix shell --refresh --impure --expr 'let nixpkgs = builtins.getFlake "github:nixos/nixpkgs/nixpkgs-unstable"; pkgs = import nixpkgs { config.allowUnfree = true; }; in (pkgs.python3.withPackages (ps: [ ps.jax ps.jaxlibWithCuda ]))' --command python -c "import jax; print(jax.devices())"

yields [cuda(id=0), cuda(id=1)].

But when numpyro is overridden to use jaxlibWithCuda, for some reason the propagated jaxlib is still the CPU version:

nix shell --refresh --impure --expr 'let nixpkgs = builtins.getFlake "github:nixos/nixpkgs/nixpkgs-unstable"; pkgs = import nixpkgs { config.allowUnfree = true; }; in (pkgs.python3.withPackages (ps: [ ((ps.numpyro.overridePythonAttrs (_: { doCheck = false; })).override { jaxlib = ps.jaxlibWithCuda; }) ]))' --command python -c "import jax; print(jax.devices())"

yields [CpuDevice(id=0)]. (Furthermore, if we try to add jaxlibWithCuda to the withPackages call, we get a collision error, so clearly something is propagating the CPU jaxlib 🤔)

Has anyone seen this, or have a better way to use the GPU-enabled jaxlib as a dependency?

04:30:05
@mcwitt:matrix.orgmcwitt Ack, the second one was just me being dumb. I'd mixed up the proper ordering of `override` and `overridePythonAttrs` 🤦 sorry for the noise 05:03:40
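For reference, the ordering fix described here presumably means applying `override` before `overridePythonAttrs`, the reverse of the failing command. A sketch of the reordered expression, reusing only names from the messages above (whether this is exactly what mcwitt ended up running is an assumption):

```nix
let
  nixpkgs = builtins.getFlake "github:nixos/nixpkgs/nixpkgs-unstable";
  pkgs = import nixpkgs { config.allowUnfree = true; };
in
pkgs.python3.withPackages (ps: [
  # Swap the dependency first, then adjust the build attributes.
  ((ps.numpyro.override { jaxlib = ps.jaxlibWithCuda; }).overridePythonAttrs (_: {
    doCheck = false;
  }))
])
```

As in the original commands, this would go inside `nix shell --impure --expr` followed by `--command python -c "import jax; print(jax.devices())"` to see which devices are reported.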
@ss:someonex.netSomeoneSerge (utc+3)
In reply to @mcwitt:matrix.org

Wait, why is numpyro propagating jaxlib? I thought we had a convention not to propagate jaxlib
08:10:45
@ocharles:matrix.orgocharles set a profile picture. 10:21:04
15 Jul 2024
@oak:universumi.fioak changed their profile picture. 03:16:33
@connorbaker:matrix.orgconnor (he/him) (UTC-7) Were there any changes to Python packages over the last month or so that may have caused breakages? Like default version or anything? I rebased a PR and I’m seeing more build failures than I was previously 21:25:25
@ss:someonex.netSomeoneSerge (utc+3)
In reply to @connorbaker:matrix.org
Were there any changes to Python packages over the last month or so that may have caused breakages? Like default version or anything? I rebased a PR and I’m seeing more build failures than I was previously
Like 3.12?
21:25:53

Room Version: 9