Google JAX

Google JAX
Tipusbiblioteca Python Modifica el valor a Wikidata
Versió estable
0.4.24 (6 febrer 2024) Modifica el valor a Wikidata
LlicènciaLlicència Apache, versió 2.0 Modifica el valor a Wikidata
Equip
Desenvolupador(s)Peter Hawkins (en) Tradueix, Matthew Johnson (en) Tradueix i Jacob VanderPlas (en) Tradueix Modifica el valor a Wikidata
Fonts de codi 
Codi fontCodi font Modifica el valor a Wikidata

Més informació
Lloc webjax.readthedocs.io… Modifica el valor a Wikidata

Google JAX és un marc d'aprenentatge automàtic per transformar funcions numèriques.[1][2] Es descriu com reunir una versió modificada d'autograd (obtenció automàtica de la funció de gradient mitjançant la diferenciació d'una funció) i XLA de TensorFlow (àlgebra lineal accelerada). Està dissenyat per seguir l'estructura i el flux de treball de NumPy tan de prop com sigui possible i funciona amb diversos marcs existents com TensorFlow i PyTorch.[3][4] Les funcions principals de JAX són:

  1. grau: diferenciació automàtica
  2. jit: compilació
  3. vmap: vectorització automàtica
  4. pmap: programació SPMD

Funció grau

El codi següent mostra la diferenciació automàtica de la funció de graduació .

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x): 
  return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)  
print(grad_log_out)

Funcio jit

El codi següent mostra l'optimització de la funció jit mitjançant la fusió.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
  return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

Funció vmap

El codi següent mostra la vectorització de la funció vmap.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
  in_grad_partial = partial(self._net_grads, self._net_params)
  grad_vmap = vmap(in_grad_partial)
  rich_grads = grad_vmap(inputs)
  flat_grads = np.asarray(self._flatten_batch(rich_grads))
  assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
  return flat_grads

Funció pmap

El codi següent mostra la paral·lelització de la funció pmap per a la multiplicació de matrius.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

Biblioteques que utilitzen JAX

Diverses biblioteques de Python utilitzen JAX com a backend, incloent:

Referències

  1. Frostig, Roy; Johnson, Matthew James; Leary, Chris MLsys, 02-02-2018, pàg. 1–3.
  2. «Using JAX to accelerate our research» (en anglès). www.deepmind.com. Arxivat de l'original el 2022-06-18. [Consulta: 18 juny 2022].
  3. Lynley, Matthew. «Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta» (en anglès americà). Business Insider. Arxivat de l'original el 2022-06-21. [Consulta: 21 juny 2022].
  4. «Why is Google's JAX so popular?» (en anglès americà). Analytics India Magazine, 25-04-2022. Arxivat de l'original el 2022-06-18. [Consulta: 18 juny 2022].