saddle_points.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def saddle_points(matrix):
    """Return the coordinates of every saddle point in *matrix*.

    A saddle point is an element that equals both the maximum of its row
    and the minimum of its column. Ties count: if several elements share
    the row maximum / column minimum, each of them is reported.

    Args:
        matrix: A sequence of equal-length rows whose items are mutually
            comparable.

    Returns:
        A set of zero-based ``(row_index, column_index)`` tuples. The set
        is empty when the matrix has no rows or its rows have no columns.

    Raises:
        ValueError: If the rows do not all have the same length.
    """
    if len(matrix) == 0:
        return set()

    # Hoisted so the reference width is looked up once, not once per row.
    width = len(matrix[0])
    if any(len(row) != width for row in matrix):
        raise ValueError("Irregular matrix")

    if width == 0:
        # Rows without columns cannot contain a saddle point; bail out before
        # max() raises a misleading "empty sequence" ValueError.
        return set()

    row_max = [max(row) for row in matrix]
    col_min = [min(col) for col in zip(*matrix)]

    return {
        (m, n)
        for m, row in enumerate(matrix)
        for n, item in enumerate(row)
        if item == row_max[m] and item == col_min[n]
    }

Comments


You're not logged in right now. Please log in via GitHub to comment.