O que é o Google JAX? Tudo o que você precisa saber

O Google JAX ou Just After Execution é um framework desenvolvido pelo Google para acelerar tarefas de aprendizado de máquina.

Você pode considerá-lo uma biblioteca para Python, que ajuda na execução mais rápida de tarefas, computação científica, transformações de funções, aprendizado profundo, redes neurais e muito mais.

Sobre o Google JAX

O pacote de computação mais fundamental em Python é o pacote NumPy, que possui todas as funções, como agregações, operações vetoriais, álgebra linear, manipulações de matrizes e matrizes n-dimensionais e muitas outras funções avançadas.

E se pudéssemos acelerar ainda mais os cálculos realizados usando o NumPy – principalmente para grandes conjuntos de dados?

Temos algo que pode funcionar igualmente bem em diferentes tipos de processadores, como GPU ou TPU, sem nenhuma alteração de código?

Que tal se o sistema pudesse realizar transformações de funções compostas automaticamente e com mais eficiência?

O Google JAX é uma biblioteca (ou framework, como diz a Wikipedia) que faz exatamente isso e talvez muito mais. Ele foi desenvolvido para otimizar o desempenho e executar tarefas de aprendizado de máquina (ML) e aprendizado profundo com eficiência. O Google JAX oferece os seguintes recursos de transformação que o tornam exclusivo de outras bibliotecas de ML e ajudam na computação científica avançada para aprendizado profundo e redes neurais:

  • Diferenciação automática
  • Vetorização automática
  • Paralelização automática
  • Compilação just-in-time (JIT)

Recursos exclusivos do Google JAX

Todas as transformações usam XLA (Álgebra Linear Acelerada) para maior desempenho e otimização de memória. O XLA é um mecanismo de compilador de otimização específico de domínio que executa álgebra linear e acelera os modelos do TensorFlow. Usar XLA em cima de seu código Python não requer alterações significativas no código!

Vamos explorar em detalhes cada um desses recursos.

Recursos do Google JAX

O Google JAX vem com importantes funções de transformação que podem ser compostas para melhorar o desempenho e realizar tarefas de aprendizado profundo com mais eficiência. Por exemplo, diferenciação automática para obter o gradiente de uma função e encontrar derivadas de qualquer ordem. Da mesma forma, paralelização automática e JIT para executar várias tarefas paralelamente. Essas transformações são fundamentais para aplicações como robótica, jogos e até pesquisa.

Uma função de transformação combinável é uma função pura que transforma um conjunto de dados em outra forma. Eles são chamados de componíveis porque são autocontidos (ou seja, essas funções não têm dependências com o resto do programa) e não têm estado (ou seja, a mesma entrada sempre resultará na mesma saída).

Y(x) = T: (f(x))

Na equação acima, f(x) é a função original na qual uma transformação é aplicada. Y(x) é a função resultante após a transformação ser aplicada.

  Ryzen 4000: Seu próximo laptop para jogos será AMD em vez de Intel?

Por exemplo, se você tem uma função chamada ‘total_bill_amt’ e deseja o resultado como uma transformação de função, você pode simplesmente usar a transformação que deseja, digamos gradiente (grad):

grad_total_bill = grad(total_bill_amt)

Ao transformar funções numéricas usando funções como grad(), podemos obter facilmente suas derivadas de ordem superior, que podemos usar extensivamente em algoritmos de otimização de aprendizado profundo, como gradiente descendente, tornando os algoritmos mais rápidos e eficientes. Da mesma forma, usando jit(), podemos compilar programas Python just-in-time (preguiçosamente).

#1. Diferenciação automática

O Python usa a função autograd para diferenciar automaticamente o NumPy e o código nativo do Python. JAX usa uma versão modificada de autograd (ou seja, grad) e combina XLA (Accelerated Linear Algebra) para realizar diferenciação automática e encontrar derivados de qualquer ordem para GPU (Graphic Processing Units) e TPU (Tensor Processing Units).]

Nota rápida sobre TPU, GPU e CPU: CPU ou Unidade Central de Processamento gerencia todas as operações no computador. A GPU é um processador adicional que aumenta o poder de computação e executa operações de ponta. A TPU é uma unidade poderosa desenvolvida especificamente para cargas de trabalho complexas e pesadas, como IA e algoritmos de aprendizado profundo.

Na mesma linha da função autograd, que pode diferenciar por meio de loops, recursões, ramificações e assim por diante, JAX usa a função grad() para gradientes de modo reverso (backpropagation). Além disso, podemos diferenciar uma função de qualquer ordem usando grad:

grad(grad(grad(sen θ)))) (1.0)

Diferenciação automática de ordem superior

Como mencionamos antes, grad é bastante útil para encontrar as derivadas parciais de uma função. Podemos usar uma derivada parcial para calcular o gradiente descendente de uma função de custo em relação aos parâmetros da rede neural em aprendizado profundo para minimizar as perdas.

Calculando a derivada parcial

Suponha que uma função tenha múltiplas variáveis, x, y e z. Encontrar a derivada de uma variável mantendo as outras variáveis ​​constantes é chamado de derivada parcial. Suponha que temos uma função,

f(x,y,z) = x + 2y + z2

Exemplo para mostrar derivada parcial

A derivada parcial de x será ∂f/∂x, que nos diz como uma função muda para uma variável quando outras são constantes. Se fizermos isso manualmente, devemos escrever um programa para diferenciar, aplicá-lo para cada variável e então calcular o gradiente descendente. Isso se tornaria um assunto complexo e demorado para múltiplas variáveis.

A diferenciação automática divide a função em um conjunto de operações elementares, como +, -, *, / ou sin, cos, tan, exp, etc., e então aplica a regra da cadeia para calcular a derivada. Podemos fazer isso no modo de avanço e reverso.

Não é isso! Todos esses cálculos acontecem tão rápido (bem, pense em um milhão de cálculos semelhantes aos acima e no tempo que pode levar!). XLA cuida da velocidade e desempenho.

  Guerra Civil na Netflix de Qualquer Lugar

#2. Álgebra Linear Acelerada

Vamos pegar a equação anterior. Sem XLA, a computação levará três (ou mais) kernels, onde cada kernel executará uma tarefa menor. Por exemplo,

Kernel k1 –> x * 2y (multiplicação)

k2 –> x * 2y + z (adição)

k3 –> Redução

Se a mesma tarefa for executada pelo XLA, um único kernel cuida de todas as operações intermediárias fundindo-as. Os resultados intermediários das operações elementares são transmitidos em vez de serem armazenados na memória, economizando memória e aumentando a velocidade.

#3. Compilação na hora

JAX usa internamente o compilador XLA para aumentar a velocidade de execução. O XLA pode aumentar a velocidade da CPU, GPU e TPU. Tudo isso é possível usando a execução do código JIT. Para usar isso, podemos usar jit via importação:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Outra maneira é decorar o jit sobre a definição da função:

@jit
def my_function(x):
	…………some lines of code

Esse código é muito mais rápido porque a transformação retornará a versão compilada do código para o chamador em vez de usar o interpretador Python. Isso é particularmente útil para entradas vetoriais, como arrays e matrizes.

O mesmo vale para todas as funções python existentes. Por exemplo, funções do pacote NumPy. Nesse caso, devemos importar jax.numpy como jnp em vez de NumPy:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Depois de fazer isso, o objeto principal da matriz JAX chamado DeviceArray substitui a matriz NumPy padrão. DeviceArray é preguiçoso – os valores são mantidos no acelerador até serem necessários. Isso também significa que o programa JAX não espera que os resultados retornem ao programa de chamada (Python), seguindo um despacho assíncrono.

#4. Vetorização automática (vmap)

Em um mundo típico de aprendizado de máquina, temos conjuntos de dados com um milhão ou mais de pontos de dados. Muito provavelmente, realizaríamos alguns cálculos ou manipulações em cada um ou na maioria desses pontos de dados – o que é uma tarefa que consome muito tempo e memória! Por exemplo, se você deseja encontrar o quadrado de cada um dos pontos de dados no conjunto de dados, a primeira coisa que você pensa é criar um loop e pegar o quadrado um por um – argh!

Se criarmos esses pontos como vetores, poderíamos fazer todos os quadrados de uma só vez realizando manipulações de vetores ou matrizes nos pontos de dados com nosso NumPy favorito. E se o seu programa pudesse fazer isso automaticamente – você pode pedir mais alguma coisa? Isso é exatamente o que JAX faz! Ele pode vetorizar automaticamente todos os seus pontos de dados para que você possa executar facilmente qualquer operação neles – tornando seus algoritmos muito mais rápidos e eficientes.

JAX usa a função vmap para vetorização automática. Considere a seguinte matriz:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Fazendo apenas o acima, o método quadrado será executado para cada ponto na matriz. Mas se você fizer o seguinte:

vmap(jnp.square(x))

O método quadrado será executado apenas uma vez porque os pontos de dados agora são vetorizados automaticamente usando o método vmap antes de executar a função, e o loop é empurrado para o nível elementar de operação – resultando em uma multiplicação de matriz em vez de uma multiplicação escalar, proporcionando melhor desempenho .

  Como definir texto de dica de ferramenta de hiperlink personalizado no MS Excel

#5. Programação SPMD (pmap)

A programação SPMD – ou Programa Único de Dados Múltiplos é essencial em contextos de aprendizado profundo – você geralmente aplica as mesmas funções em diferentes conjuntos de dados que residem em várias GPUs ou TPUs. O JAX possui uma função chamada pump, que permite a programação paralela em várias GPUs ou em qualquer acelerador. Assim como o JIT, os programas usando pmap serão compilados pelo XLA e executados simultaneamente nos sistemas. Essa paralelização automática funciona para cálculos diretos e reversos.

Como funciona o pmap

Também podemos aplicar várias transformações de uma só vez em qualquer ordem em qualquer função como:

pmap(vmap(jit(grad(f(x))))))

Múltiplas transformações compostas

Limitações do Google JAX

Os desenvolvedores do Google JAX pensaram bem em acelerar os algoritmos de aprendizado profundo ao introduzir todas essas transformações incríveis. As funções e pacotes de computação científica estão nas linhas do NumPy, então você não precisa se preocupar com a curva de aprendizado. No entanto, JAX tem as seguintes limitações:

  • O Google JAX ainda está nos estágios iniciais de desenvolvimento e, embora seu objetivo principal seja a otimização de desempenho, ele não oferece muitos benefícios para a computação da CPU. O NumPy parece ter um desempenho melhor e o uso do JAX pode apenas aumentar a sobrecarga.
  • O JAX ainda está em fase de pesquisa ou estágio inicial e precisa de mais ajustes para atingir os padrões de infraestrutura de frameworks como o TensorFlow, que são mais estabelecidos e têm mais modelos pré-definidos, projetos de código aberto e material de aprendizado.
  • A partir de agora, o JAX não suporta o sistema operacional Windows – você precisaria de uma máquina virtual para fazê-lo funcionar.
  • O JAX funciona apenas em funções puras – aquelas que não têm efeitos colaterais. Para funções com efeitos colaterais, JAX pode não ser uma boa opção.

Como instalar o JAX em seu ambiente Python

Se você tiver uma configuração python em seu sistema e quiser executar o JAX em sua máquina local (CPU), use os seguintes comandos:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Se você deseja executar o Google JAX em uma GPU ou TPU, siga as instruções fornecidas em GitHub JAX página. Para configurar o Python, visite o downloads oficiais python página.

Conclusão

O Google JAX é ótimo para escrever algoritmos eficientes de aprendizado profundo, robótica e pesquisa. Apesar das limitações, é usado extensivamente com outros frameworks como Haiku, Flax e muitos outros. Você poderá apreciar o que o JAX faz ao executar programas e ver as diferenças de tempo na execução de código com e sem JAX. Você pode começar lendo o documentação oficial do Google JAXque é bastante abrangente.