argmax
是 NumPy 和许多其他科学计算库(如 PyTorch、TensorFlow)中的一个非常常用的函数,它的作用是返回数组中最大值的索引。
简单来说,argmax
告诉你最大值在哪里,而不是最大值是多少。
argmax
的基本用法
np.argmax(a, axis=None, out=None)
a
: 你要查找最大值索引的数组。axis
: (可选)指定在哪个维度上查找。None
(默认值):在整个数组中查找最大值的索引。- 一个整数:指定维度。例如,
axis=0
表示按列查找,axis=1
表示按行查找。
1. 在一维数组中的应用
对于一维数组,argmax
会返回单个索引。
import numpy as npscores = np.array([85, 90, 78, 92, 88])
# 数组中最大值是 92,它的索引是 3 (从 0 开始)
max_index = np.argmax(scores)print(f"最大值的索引是: {max_index}")
# 输出: 最大值的索引是: 3
print(f"最大值是: {scores[max_index]}")
# 输出: 最大值是: 92
2. 在二维数组中的应用
在二维数组中,axis
参数变得非常重要。
import numpy as np# 假设这是一个 3x4 的数组
grades = np.array([[80, 85, 90, 75],[95, 88, 92, 90],[70, 75, 80, 85]])# 在整个数组中查找最大值的索引
# 先将数组展平为一维,再查找
overall_max_index = np.argmax(grades)
print(f"整个数组最大值的索引 (展平后): {overall_max_index}")
# 输出: 整个数组最大值的索引 (展平后): 4# 按列查找最大值的索引 (axis=0)
max_per_column = np.argmax(grades, axis=0)
print(f"每列最大值的索引: {max_per_column}")
# 输出: 每列最大值的索引: [1 0 1 2]
# 解释: 第0列最大值95在索引1,第1列最大值85在索引0,以此类推。# 按行查找最大值的索引 (axis=1)
max_per_row = np.argmax(grades, axis=1)
print(f"每行最大值的索引: {max_per_row}")
# 输出: 每行最大值的索引: [2 0 3]
# 解释: 第0行最大值90在索引2,第1行最大值95在索引0,以此类推。
在机器学习中的应用
argmax
在机器学习中非常常见,尤其是在多分类任务中。
一个典型的神经网络会对每个可能的类别输出一个分数或概率。为了做出最终的预测,我们需要找出哪个类别的分数最高。
在你的手写数字识别代码中,np.argmax(pre, axis=1)
的作用就是:
pre
是一个形状为(1, 10)
的二维数组,其中每个值代表模型对 0-9 这 10 个数字的原始预测分数。axis=1
告诉argmax
在这个数组的行方向上查找最大值。- 它返回最大值所在位置的索引。这个索引,就是模型最终预测的数字。
例如,如果 pre
是 [[ -1.5, 0.2, 3.8, ... ]]
,np.argmax(pre, axis=1)
会返回 [2]
,这告诉我们模型预测的数字是 2。