Domine o NumPy: Encontre o Índice do Maior Elemento com argmax()

Neste guia, você vai aprender a usar a função argmax() do NumPy para descobrir o índice do maior elemento em arrays.

O NumPy é uma biblioteca essencial para computação científica em Python. Ela oferece arrays n-dimensionais que superam as listas do Python em desempenho. Uma tarefa comum ao trabalhar com arrays NumPy é identificar o valor máximo. No entanto, em certas ocasiões, pode ser necessário saber a posição (índice) em que esse valor máximo se encontra.

A função argmax() facilita a localização do índice do valor máximo tanto em arrays unidimensionais quanto multidimensionais. Vamos explorar como ela funciona.

Descobrindo o Índice do Maior Elemento em um Array NumPy

Para acompanhar este tutorial, você precisa ter Python e NumPy instalados. Você pode começar a codificar usando um REPL Python ou um notebook Jupyter.

Primeiramente, importe o NumPy usando o alias comum np.

import numpy as np

Você pode utilizar a função max() do NumPy para obter o maior valor em um array (opcionalmente, ao longo de um eixo específico).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Saída
10

Nesse exemplo, np.max(array_1) retorna 10, que é o valor máximo correto.

Imagine que você precise saber o índice onde o valor máximo ocorre no array. Para isso, você poderia seguir estes dois passos:

  • Achar o elemento máximo.
  • Localizar o índice desse elemento.

Em array_1, o valor máximo 10 está no índice 4, seguindo a indexação que começa em zero. O primeiro elemento está no índice 0, o segundo no índice 1 e assim por diante.

Para encontrar o índice do máximo, você pode usar a função where() do NumPy. np.where(condição) retorna um array com todos os índices onde a condição é verdadeira.

Será necessário acessar o array e pegar o valor no primeiro índice. Para descobrir onde o máximo aparece, definimos a condição como array_1==10, lembrando que 10 é o máximo em array_1.

print(int(np.where(array_1==10)[0]))

# Saída
4

Usamos np.where() apenas com a condição, mas essa não é a forma mais recomendada de usar essa função.

📑 Nota: Função where() do NumPy:
np.where(condição, x, y) retorna:

  • Elementos de x quando a condição é verdadeira, e
  • Elementos de y quando a condição é falsa.

Assim, encadeando as funções np.max() e np.where(), podemos encontrar o maior elemento e seu índice.

Contudo, em vez desse processo de duas etapas, a função argmax() do NumPy fornece o índice do elemento máximo diretamente.

Sintaxe da Função argmax() do NumPy

A forma geral de usar a função argmax() do NumPy é:

np.argmax(array, axis, out)
# importamos o numpy como np

Na sintaxe acima:

  • array é qualquer array NumPy válido.
  • axis é um parâmetro opcional. Em arrays multidimensionais, permite encontrar o índice do máximo ao longo de um eixo específico.
  • out também é opcional. Você pode definir out para um array NumPy, onde o resultado da função argmax() será armazenado.

Nota: A partir da versão 1.22.0 do NumPy, existe um parâmetro adicional keepdims. Ao especificar o parâmetro axis, o array é reduzido ao longo desse eixo. Definir keepdims como True garante que a saída tenha a mesma forma do array de entrada.

Usando argmax() para Encontrar o Índice do Maior Elemento

#1. Vamos utilizar a função argmax() para encontrar o índice do maior elemento em array_1.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# Saída
4

A função argmax() retorna 4, que está correto! ✅

#2. Se redefinirmos array_1 com 10 ocorrendo duas vezes, argmax() retornará o índice da primeira ocorrência.

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# Saída
4

Nos exemplos seguintes, usaremos os elementos de array_1 definidos no exemplo #1.

Usando argmax() em um Array 2D

Vamos transformar array_1 em um array bidimensional com duas linhas e quatro colunas.

array_2 = array_1.reshape(2,4)
print(array_2)

# Saída
[[ 1  5  7  2]
 [10  9  8  4]]

Em um array bidimensional, o eixo 0 representa as linhas e o eixo 1 as colunas. Arrays NumPy utilizam indexação baseada em zero. Assim, os índices das linhas e colunas de array_2 são:

Agora, vamos chamar argmax() em array_2.

print(np.argmax(array_2))

# Saída
4

Apesar de chamarmos argmax() no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional array_1.

Por que isso ocorre? 🤔

Isso acontece porque não definimos nenhum valor para o parâmetro axis. Por padrão, quando axis não é definido, argmax() retorna o índice do maior elemento ao longo do array “achatado”.

O que é um array achatado? Se você tem um array N-dimensional de forma d1 x d2 x … x dN, onde d1, d2, até dN são os tamanhos nas N dimensões, o array achatado é um array unidimensional de tamanho d1 * d2 * … * dN.

Para ver a aparência do array achatado de array_2, use o método flatten():

array_2.flatten()

# Saída
array([ 1,  5,  7,  2, 10,  9,  8,  4])

Índice do Maior Elemento ao Longo das Linhas (axis=0)

Agora, vamos encontrar o índice do maior elemento ao longo das linhas (axis=0).

np.argmax(array_2,axis=0)

# Saída
array([1, 1, 1, 1])

Essa saída pode ser um pouco difícil de entender, mas vamos analisar como ela funciona.

Definimos o parâmetro axis como zero (axis=0), pois queremos o índice do maior elemento ao longo das linhas. Assim, argmax() retorna o número da linha onde o maior elemento se encontra – para cada coluna.

Vamos visualizar isso para melhor compreensão.

Do diagrama e da saída de argmax(), temos:

  • Na primeira coluna (índice 0), o maior valor 10 está na segunda linha, no índice 1.
  • Na segunda coluna (índice 1), o maior valor 9 está na segunda linha, no índice 1.
  • Na terceira e quarta colunas (índices 2 e 3), os maiores valores 8 e 4 também estão na segunda linha, no índice 1.

É por isso que a saída é ([1, 1, 1, 1]), porque o maior elemento ao longo das linhas está na segunda linha (para todas as colunas).

Índice do Maior Elemento ao Longo das Colunas (axis=1)

Agora, vamos usar argmax() para encontrar o índice do maior elemento ao longo das colunas.

Execute o código abaixo e observe a saída.

np.argmax(array_2,axis=1)
array([2, 0])

Consegue entender a saída?

Definimos axis=1 para calcular o índice do maior elemento ao longo das colunas.

A função argmax() retorna, para cada linha, o número da coluna onde o maior valor ocorre.

Aqui está uma explicação visual:

Do diagrama e da saída de argmax(), temos:

  • Na primeira linha (índice 0), o maior valor 7 está na terceira coluna, no índice 2.
  • Na segunda linha (índice 1), o maior valor 10 está na primeira coluna, no índice 0.

Esperamos que agora você entenda que array([2, 0]) significa isso.

Usando o Parâmetro Opcional out em argmax()

Você pode usar o parâmetro opcional out na função argmax() para armazenar a saída em um array NumPy.

Vamos criar um array de zeros para guardar a saída da chamada anterior de argmax() — para encontrar o índice do maior elemento ao longo das colunas (axis=1).

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

Agora, vamos revisitar o exemplo de encontrar o índice do maior elemento ao longo das colunas (axis=1) e definir out como out_arr que criamos acima.

np.argmax(array_2,axis=1,out=out_arr)

O interpretador Python retorna um TypeError, pois out_arr foi inicializado como um array de floats por padrão.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

Ao definir o parâmetro out para o array de saída, é fundamental garantir que o array tenha a forma e o tipo de dados corretos. Como os índices de array são sempre números inteiros, devemos definir o parâmetro dtype como int ao criar o array de saída.

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# Saída
[0 0]

Agora podemos chamar argmax() com os parâmetros axis e out, e desta vez, ele é executado sem erros.

np.argmax(array_2,axis=1,out=out_arr)

O resultado de argmax() agora pode ser acessado em out_arr.

print(out_arr)
# Saída
[2 0]

Conclusão

Esperamos que este guia tenha ajudado você a entender como usar a função argmax() do NumPy. Você pode testar os exemplos de código em um notebook Jupyter.

Vamos revisar o que aprendemos.

  • A função argmax() do NumPy retorna o índice do maior elemento em um array. Se o maior elemento aparecer mais de uma vez, np.argmax(a) retorna o índice da primeira ocorrência.
  • Em arrays multidimensionais, você pode usar o parâmetro opcional axis para obter o índice do maior elemento ao longo de um eixo específico. Por exemplo, em um array bidimensional: axis=0 e axis=1 para obter o índice do maior elemento ao longo de linhas e colunas, respectivamente.
  • Para guardar o resultado em outro array, use o parâmetro out. No entanto, o array de saída precisa ter o formato e tipo de dados corretos.

Em seguida, confira o guia detalhado sobre conjuntos (sets) em Python.