Neste tutorial, você aprenderá como usar a função NumPy argmax() para encontrar o índice do elemento máximo em arrays.
NumPy é uma biblioteca poderosa para computação científica em Python; ele fornece arrays N-dimensionais com melhor desempenho do que as listas do Python. Uma das operações comuns que você realizará ao trabalhar com matrizes NumPy é encontrar o valor máximo na matriz. No entanto, às vezes você pode querer localizar o índice no qual ocorre o valor máximo.
A função argmax() ajuda a encontrar o índice do máximo em arrays unidimensionais e multidimensionais. Vamos continuar a aprender como funciona.
últimas postagens
Como encontrar o índice do elemento máximo em uma matriz NumPy
Para acompanhar este tutorial, você precisa ter o Python e o NumPy instalados. Você pode codificar iniciando um Python REPL ou iniciando um notebook Jupyter.
Primeiro, vamos importar o NumPy sob o alias usual np.
import numpy as np
Você pode usar a função NumPy max() para obter o valor máximo em uma matriz (opcionalmente ao longo de um eixo específico).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
Nesse caso, np.max(array_1) retorna 10, o que está correto.
Suponha que você queira encontrar o índice no qual o valor máximo ocorre na matriz. Você pode adotar a seguinte abordagem em duas etapas:
Em array_1, o valor máximo de 10 ocorre no índice 4, seguindo a indexação zero. O primeiro elemento está no índice 0; o segundo elemento está no índice 1 e assim por diante.
Para encontrar o índice no qual ocorre o máximo, você pode usar a função NumPy where(). np.where(condition) retorna um array de todos os índices onde a condição é True.
Você terá que tocar no array e acessar o item no primeiro índice. Para descobrir onde ocorre o valor máximo, definimos a condição para array_1==10; lembre-se de que 10 é o valor máximo em array_1.
print(int(np.where(array_1==10)[0])) # Output 4
Usamos np.where() apenas com a condição, mas este não é o método recomendado para usar esta função.
📑 Nota: NumPy where() Função:
np.where(condição,x,y) retorna:
– Elementos de x quando a condição for True, e
– Elementos de y quando a condição for False.
Portanto, encadeando as funções np.max() e np.where(), podemos encontrar o elemento máximo, seguido do índice em que ele ocorre.
Em vez do processo de duas etapas acima, você pode usar a função NumPy argmax() para obter o índice do elemento máximo na matriz.
Sintaxe da função NumPy argmax()
A sintaxe geral para usar a função NumPy argmax() é a seguinte:
np.argmax(array,axis,out) # we've imported numpy under the alias np
Na sintaxe acima:
- array é qualquer array NumPy válido.
- axis é um parâmetro opcional. Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro axis para encontrar o índice de máximo ao longo de um eixo específico.
- out é outro parâmetro opcional. Você pode definir o parâmetro out para um array NumPy para armazenar a saída da função argmax().
Nota: A partir da versão 1.22.0 do NumPy, há um parâmetro keepdims adicional. Quando especificamos o parâmetro axis na chamada da função argmax(), a matriz é reduzida ao longo desse eixo. Mas definir o parâmetro keepdims como True garante que a saída retornada tenha a mesma forma que a matriz de entrada.
Usando NumPy argmax() para encontrar o índice do elemento máximo
#1. Vamos usar a função NumPy argmax() para encontrar o índice do elemento máximo em array_1.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
A função argmax() retorna 4, o que está correto! ✅
#2. Se redefinirmos array_1 de forma que 10 ocorra duas vezes, a função argmax() retornará apenas o índice da primeira ocorrência.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
Para o restante dos exemplos, usaremos os elementos de array_1 que definimos no exemplo #1.
Usando NumPy argmax() para encontrar o índice do elemento máximo em uma matriz 2D
Vamos remodelar o array NumPy array_1 em um array bidimensional com duas linhas e quatro colunas.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
Para uma matriz bidimensional, o eixo 0 denota as linhas e o eixo 1 denota as colunas. As matrizes NumPy seguem a indexação zero. Portanto, os índices das linhas e colunas da matriz NumPy array_2 são os seguintes:
Agora, vamos chamar a função argmax() no array bidimensional, array_2.
print(np.argmax(array_2)) # Output 4
Embora tenhamos chamado argmax() no array bidimensional, ele ainda retorna 4. Isso é idêntico à saída para o array unidimensional, array_1 da seção anterior.
Por que isso acontece? 🤔
Isso ocorre porque não especificamos nenhum valor para o parâmetro axis. Quando este parâmetro de eixo não está definido, por padrão, a função argmax() retorna o índice do elemento máximo ao longo do array nivelado.
O que é uma matriz achatada? Se houver uma matriz N-dimensional de forma d1 x d2 x … x dN, onde d1, d2, até dN são os tamanhos da matriz ao longo das N dimensões, então a matriz achatada é uma longa matriz unidimensional de tamanho d1 * d2 * … * dN.
Para verificar a aparência do array achatado para array_2, você pode chamar o método flatten(), conforme mostrado abaixo:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Índice do elemento máximo ao longo das linhas (eixo = 0)
Vamos prosseguir para encontrar o índice do elemento máximo ao longo das linhas (eixo = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Essa saída pode ser um pouco difícil de compreender, mas vamos entender como ela funciona.
Definimos o parâmetro axis como zero (axis = 0), pois gostaríamos de encontrar o índice do elemento máximo ao longo das linhas. Portanto, a função argmax() retorna o número da linha em que o elemento máximo ocorre — para cada uma das três colunas.
Vamos visualizar isso para melhor compreensão.
Do diagrama acima e da saída argmax(), temos o seguinte:
- Para a primeira coluna no índice 0, o valor máximo 10 ocorre na segunda linha, no índice = 1.
- Para a segunda coluna no índice 1, o valor máximo 9 ocorre na segunda linha, no índice = 1.
- Para a terceira e quarta colunas no índice 2 e 3, os valores máximos 8 e 4 ocorrem na segunda linha, no índice = 1.
É precisamente por isso que temos o array de saída ([1, 1, 1, 1]) porque o elemento máximo ao longo das linhas ocorre na segunda linha (para todas as colunas).
Índice do Elemento Máximo ao Longo das Colunas (eixo = 1)
Em seguida, vamos usar a função argmax() para encontrar o índice do elemento máximo ao longo das colunas.
Execute o trecho de código a seguir e observe a saída.
np.argmax(array_2,axis=1)
array([2, 0])
Você pode analisar a saída?
Definimos axis = 1 para calcular o índice do elemento máximo ao longo das colunas.
A função argmax() retorna, para cada linha, o número da coluna em que ocorre o valor máximo.
Aqui está uma explicação visual:
Do diagrama acima e da saída argmax(), temos o seguinte:
- Para a primeira linha no índice 0, o valor máximo 7 ocorre na terceira coluna, no índice = 2.
- Para a segunda linha no índice 1, o valor máximo 10 ocorre na primeira coluna, no índice = 0.
Espero que agora você entenda qual é a saída, array([2, 0]) significa.
Usando o parâmetro opcional de saída no NumPy argmax ()
Você pode usar o parâmetro opcional out the na função NumPy argmax() para armazenar a saída em uma matriz NumPy.
Vamos inicializar uma matriz de zeros para armazenar a saída da chamada de função argmax() anterior – para encontrar o índice do máximo ao longo das colunas (eixo= 1).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Agora, vamos revisitar o exemplo de encontrar o índice do elemento máximo ao longo das colunas (eixo = 1) e definir o out como out_arr que definimos acima.
np.argmax(array_2,axis=1,out=out_arr)
Vemos que o interpretador Python lança um TypeError, pois o out_arr foi inicializado para um array de floats por padrão.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Portanto, ao definir o parâmetro out para a matriz de saída, é importante garantir que a matriz de saída tenha a forma e o tipo de dados corretos. Como os índices de array são sempre inteiros, devemos definir o parâmetro dtype como int ao definir o array de saída.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Agora podemos ir em frente e chamar a função argmax() com os parâmetros axis e out e, desta vez, ela é executada sem erros.
np.argmax(array_2,axis=1,out=out_arr)
A saída da função argmax() agora pode ser acessada no array out_arr.
print(out_arr) # Output [2 0]
Conclusão
Espero que este tutorial tenha ajudado você a entender como usar a função NumPy argmax(). Você pode executar os exemplos de código em um notebook Jupyter.
Vamos rever o que aprendemos.
- A função NumPy argmax() retorna o índice do elemento máximo em uma matriz. Se o elemento máximo ocorrer mais de uma vez em uma matriz a, então np.argmax(a) retornará o índice da primeira ocorrência do elemento.
- Ao trabalhar com matrizes multidimensionais, você pode usar o parâmetro opcional axis para obter o índice do elemento máximo ao longo de um eixo específico. Por exemplo, em uma matriz bidimensional: definindo axis = 0 e axis = 1, você pode obter o índice do elemento máximo ao longo das linhas e colunas, respectivamente.
- Se você quiser armazenar o valor retornado em outro array, você pode definir o parâmetro opcional out para o array de saída. No entanto, a matriz de saída deve ter formato e tipo de dados compatíveis.
Em seguida, confira o guia detalhado sobre conjuntos de Python.