class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses path compression in ``find`` and union by size in ``union``.
    ``p`` holds each node's parent (a root is its own parent); ``size`` is
    only meaningful at roots and counts the nodes in that root's set.
    """

    def __init__(self, n):
        # Every node starts as a singleton set rooted at itself.
        self.p = [node for node in range(n)]
        self.size = [1 for _ in range(n)]

    def find(self, x):
        """Return the root of ``x``'s set, pointing every node on the path at it."""
        root = x
        while self.p[root] != root:
            root = self.p[root]
        # Second pass: compress the path so later lookups are O(1).
        node = x
        while node != root:
            successor = self.p[node]
            self.p[node] = root
            node = successor
        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b`` (no-op if already merged)."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return
        # Hang the smaller tree under the larger; on a tie, a's root goes under b's.
        if self.size[root_a] > self.size[root_b]:
            child, parent = root_b, root_a
        else:
            child, parent = root_a, root_b
        self.p[child] = parent
        self.size[parent] += self.size[child]

    def reset(self, x):
        """Turn ``x`` back into a singleton root (caller must reset its whole set)."""
        self.p[x] = x
        self.size[x] = 1
class Solution:
    """Solution wrapper exposing ``matrixRankTransform``."""

    # NOTE(review): `List` (typing) and `UnionFind` must be in module scope;
    # `List` is not imported anywhere in the visible part of this file.
    def matrixRankTransform(self, matrix: List[List[int]]) -> List[List[int]]:
        """Return a new matrix holding the rank of each cell of ``matrix``.

        Values are processed in increasing order. Cells with equal value that
        are linked through a shared row or column (transitively) form one
        group and all receive the same rank: one more than the largest rank
        already assigned anywhere in the rows/columns that group touches.

        Args:
            matrix: rectangular list of integer rows.

        Returns:
            A fresh ``len(matrix) x len(matrix[0])`` list of ranks starting at 1;
            ``[]`` when ``matrix`` is empty.
        """
        from collections import defaultdict

        # Guard: the original indexed matrix[0] unconditionally (IndexError on []).
        if not matrix:
            return []
        m, n = len(matrix), len(matrix[0])

        # Group cell coordinates by value so each distinct value is handled once.
        cells_by_value = defaultdict(list)
        for i, row in enumerate(matrix):
            for j, v in enumerate(row):
                cells_by_value[v].append((i, j))

        # Largest rank assigned so far in each row / column.
        row_max = [0] * m
        col_max = [0] * n
        res = [[0] * n for _ in range(m)]
        # Union-find nodes 0..m-1 are rows, m..m+n-1 are columns.
        uf = UnionFind(m + n)

        for v in sorted(cells_by_value):
            cells = cells_by_value[v]
            # A cell links its row with its column; equal-valued cells sharing
            # a row or column end up in the same component.
            for i, j in cells:
                uf.union(i, j + m)
            # For each component, the max rank already present in its rows/cols.
            group_rank = defaultdict(int)
            for i, j in cells:
                root = uf.find(i)
                group_rank[root] = max(group_rank[root], row_max[i], col_max[j])
            # Chained assignment: the RHS is evaluated once and stored in all three.
            for i, j in cells:
                res[i][j] = row_max[i] = col_max[j] = 1 + group_rank[uf.find(i)]
            # Only this value's rows/columns were linked, so resetting them
            # returns the union-find to all-singletons for the next value.
            for i, j in cells:
                uf.reset(i)
                uf.reset(j + m)
        return res