Neste guia, você vai aprender a usar a função argmax()
do NumPy para descobrir o índice do maior elemento em arrays.
O NumPy é uma biblioteca essencial para computação científica em Python. Ela oferece arrays n-dimensionais que superam as listas do Python em desempenho. Uma tarefa comum ao trabalhar com arrays NumPy é identificar o valor máximo. No entanto, em certas ocasiões, pode ser necessário saber a posição (índice) em que esse valor máximo se encontra.
A função argmax()
facilita a localização do índice do valor máximo tanto em arrays unidimensionais quanto multidimensionais. Vamos explorar como ela funciona.
Descobrindo o Índice do Maior Elemento em um Array NumPy
Para acompanhar este tutorial, você precisa ter Python e NumPy instalados. Você pode começar a codificar usando um REPL Python ou um notebook Jupyter.
Primeiramente, importe o NumPy usando o alias comum np
.
import numpy as np
Você pode utilizar a função max()
do NumPy para obter o maior valor em um array (opcionalmente, ao longo de um eixo específico).
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# Saída
10
Nesse exemplo, np.max(array_1)
retorna 10, que é o valor máximo correto.
Imagine que você precise saber o índice onde o valor máximo ocorre no array. Para isso, você poderia seguir estes dois passos:
- Achar o elemento máximo.
- Localizar o índice desse elemento.
Em array_1
, o valor máximo 10 está no índice 4, seguindo a indexação que começa em zero. O primeiro elemento está no índice 0, o segundo no índice 1 e assim por diante.
Para encontrar o índice do máximo, você pode usar a função where()
do NumPy. np.where(condição)
retorna um array com todos os índices onde a condição é verdadeira.
Será necessário acessar o array e pegar o valor no primeiro índice. Para descobrir onde o máximo aparece, definimos a condição como array_1==10
, lembrando que 10 é o máximo em array_1
.
print(int(np.where(array_1==10)[0]))
# Saída
4
Usamos np.where()
apenas com a condição, mas essa não é a forma mais recomendada de usar essa função.
📑 Nota: Função where()
do NumPy:
np.where(condição, x, y)
retorna:
- Elementos de
x
quando a condição é verdadeira, e - Elementos de
y
quando a condição é falsa.
Assim, encadeando as funções np.max()
e np.where()
, podemos encontrar o maior elemento e seu índice.
Contudo, em vez desse processo de duas etapas, a função argmax()
do NumPy fornece o índice do elemento máximo diretamente.
Sintaxe da Função argmax()
do NumPy
A forma geral de usar a função argmax()
do NumPy é:
np.argmax(array, axis, out)
# importamos o numpy como np
Na sintaxe acima:
array
é qualquer array NumPy válido.axis
é um parâmetro opcional. Em arrays multidimensionais, permite encontrar o índice do máximo ao longo de um eixo específico.out
também é opcional. Você pode definirout
para um array NumPy, onde o resultado da funçãoargmax()
será armazenado.
Nota: A partir da versão 1.22.0 do NumPy, existe um parâmetro adicional keepdims
. Ao especificar o parâmetro axis
, o array é reduzido ao longo desse eixo. Definir keepdims
como True
garante que a saída tenha a mesma forma do array de entrada.
Usando argmax()
para Encontrar o Índice do Maior Elemento
#1. Vamos utilizar a função argmax()
para encontrar o índice do maior elemento em array_1
.
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# Saída
4
A função argmax()
retorna 4, que está correto! ✅
#2. Se redefinirmos array_1
com 10 ocorrendo duas vezes, argmax()
retornará o índice da primeira ocorrência.
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# Saída
4
Nos exemplos seguintes, usaremos os elementos de array_1
definidos no exemplo #1.
Usando argmax()
em um Array 2D
Vamos transformar array_1
em um array bidimensional com duas linhas e quatro colunas.
array_2 = array_1.reshape(2,4)
print(array_2)
# Saída
[[ 1 5 7 2]
[10 9 8 4]]
Em um array bidimensional, o eixo 0 representa as linhas e o eixo 1 as colunas. Arrays NumPy utilizam indexação baseada em zero. Assim, os índices das linhas e colunas de array_2
são:
Agora, vamos chamar argmax()
em array_2
.
print(np.argmax(array_2))
# Saída
4
Apesar de chamarmos argmax()
no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional array_1
.
Por que isso ocorre? 🤔
Isso acontece porque não definimos nenhum valor para o parâmetro axis
. Por padrão, quando axis
não é definido, argmax()
retorna o índice do maior elemento ao longo do array “achatado”.
O que é um array achatado? Se você tem um array N-dimensional de forma d1 x d2 x … x dN, onde d1, d2, até dN são os tamanhos nas N dimensões, o array achatado é um array unidimensional de tamanho d1 * d2 * … * dN.
Para ver a aparência do array achatado de array_2
, use o método flatten()
:
array_2.flatten()
# Saída
array([ 1, 5, 7, 2, 10, 9, 8, 4])
Índice do Maior Elemento ao Longo das Linhas (axis=0
)
Agora, vamos encontrar o índice do maior elemento ao longo das linhas (axis=0
).
np.argmax(array_2,axis=0)
# Saída
array([1, 1, 1, 1])
Essa saída pode ser um pouco difícil de entender, mas vamos analisar como ela funciona.
Definimos o parâmetro axis
como zero (axis=0
), pois queremos o índice do maior elemento ao longo das linhas. Assim, argmax()
retorna o número da linha onde o maior elemento se encontra – para cada coluna.
Vamos visualizar isso para melhor compreensão.
Do diagrama e da saída de argmax()
, temos:
- Na primeira coluna (índice 0), o maior valor 10 está na segunda linha, no índice 1.
- Na segunda coluna (índice 1), o maior valor 9 está na segunda linha, no índice 1.
- Na terceira e quarta colunas (índices 2 e 3), os maiores valores 8 e 4 também estão na segunda linha, no índice 1.
É por isso que a saída é ([1, 1, 1, 1])
, porque o maior elemento ao longo das linhas está na segunda linha (para todas as colunas).
Índice do Maior Elemento ao Longo das Colunas (axis=1
)
Agora, vamos usar argmax()
para encontrar o índice do maior elemento ao longo das colunas.
Execute o código abaixo e observe a saída.
np.argmax(array_2,axis=1)
array([2, 0])
Consegue entender a saída?
Definimos axis=1
para calcular o índice do maior elemento ao longo das colunas.
A função argmax()
retorna, para cada linha, o número da coluna onde o maior valor ocorre.
Aqui está uma explicação visual:
Do diagrama e da saída de argmax()
, temos:
- Na primeira linha (índice 0), o maior valor 7 está na terceira coluna, no índice 2.
- Na segunda linha (índice 1), o maior valor 10 está na primeira coluna, no índice 0.
Esperamos que agora você entenda que array([2, 0])
significa isso.
Usando o Parâmetro Opcional out
em argmax()
Você pode usar o parâmetro opcional out
na função argmax()
para armazenar a saída em um array NumPy.
Vamos criar um array de zeros para guardar a saída da chamada anterior de argmax()
— para encontrar o índice do maior elemento ao longo das colunas (axis=1
).
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
Agora, vamos revisitar o exemplo de encontrar o índice do maior elemento ao longo das colunas (axis=1
) e definir out
como out_arr
que criamos acima.
np.argmax(array_2,axis=1,out=out_arr)
O interpretador Python retorna um TypeError
, pois out_arr
foi inicializado como um array de floats por padrão.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except TypeError:
TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Ao definir o parâmetro out
para o array de saída, é fundamental garantir que o array tenha a forma e o tipo de dados corretos. Como os índices de array são sempre números inteiros, devemos definir o parâmetro dtype
como int
ao criar o array de saída.
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# Saída
[0 0]
Agora podemos chamar argmax()
com os parâmetros axis
e out
, e desta vez, ele é executado sem erros.
np.argmax(array_2,axis=1,out=out_arr)
O resultado de argmax()
agora pode ser acessado em out_arr
.
print(out_arr)
# Saída
[2 0]
Conclusão
Esperamos que este guia tenha ajudado você a entender como usar a função argmax()
do NumPy. Você pode testar os exemplos de código em um notebook Jupyter.
Vamos revisar o que aprendemos.
- A função
argmax()
do NumPy retorna o índice do maior elemento em um array. Se o maior elemento aparecer mais de uma vez,np.argmax(a)
retorna o índice da primeira ocorrência. - Em arrays multidimensionais, você pode usar o parâmetro opcional
axis
para obter o índice do maior elemento ao longo de um eixo específico. Por exemplo, em um array bidimensional:axis=0
eaxis=1
para obter o índice do maior elemento ao longo de linhas e colunas, respectivamente. - Para guardar o resultado em outro array, use o parâmetro
out
. No entanto, o array de saída precisa ter o formato e tipo de dados corretos.
Em seguida, confira o guia detalhado sobre conjuntos (sets) em Python.