a.py (892B)
"""Connect the closest 3-D points and report the sizes of the largest groups.

Reads one ``x,y,z`` integer triple per line (from the files named on the
command line, or stdin), joins the ``N`` closest pairs of points, and prints
the product of the sizes of the three largest resulting groups.
"""
import fileinput
import heapq
import itertools

N = 1000  # Number of closest pairs to connect; the worked example uses 10.


class SetJoin:
    """Disjoint-set (union-find) over the integers ``0 .. n-1``."""

    def __init__(self, n):
        # Every element starts out as the representative of its own set.
        self._parent = list(range(n))

    def rep(self, i):
        """Return the representative of the set containing ``i``.

        Iterative with full path compression: ``join`` does no balancing, so
        parent chains can grow as long as the number of elements, which would
        overflow the interpreter's recursion limit in a recursive version.
        """
        parent = self._parent
        root = i
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the walked path directly at root.
        while i != root:
            nxt = parent[i]
            parent[i] = root
            i = nxt
        return root

    def join(self, i, j):
        """Merge the set containing ``i`` into the set containing ``j``."""
        self._parent[self.rep(i)] = self.rep(j)

    def sizes(self):
        """Return a list holding each representative's set size at its index.

        Entries at indices that are not representatives are 0.
        """
        sizes = [0] * len(self._parent)
        for i in range(len(self._parent)):
            sizes[self.rep(i)] += 1
        return sizes


def dist(p, q):
    """Return the squared Euclidean distance between points ``p`` and ``q``.

    The square root is skipped because it does not change the ordering.
    """
    return sum((a - b) ** 2 for a, b in zip(p, q))


def parse_points(lines):
    """Parse ``x,y,z`` lines into tuples of ints, skipping blank lines.

    ``strip()`` is used instead of dropping the last character, so the final
    coordinate stays intact when the last line has no trailing newline.
    """
    return [tuple(int(x) for x in line.strip().split(','))
            for line in lines if line.strip()]


def solve(pts, connections=N):
    """Join the ``connections`` closest pairs; multiply the 3 largest group sizes.

    Pairs are ranked by ``(squared distance, i, j)``, so ties break by index.
    A pair whose points are already in the same group still counts toward
    ``connections``. If there are fewer pairs than ``connections``, all pairs
    are used.

    Raises:
        ValueError: if fewer than three points are given.
    """
    if len(pts) < 3:
        raise ValueError("need at least three points")
    # Same result as sorted(all_pairs)[:connections], without sorting every pair.
    closest = heapq.nsmallest(
        connections,
        ((dist(pts[i], pts[j]), i, j)
         for i, j in itertools.combinations(range(len(pts)), 2)),
    )
    sj = SetJoin(len(pts))
    for _, j, k in closest:
        if sj.rep(j) != sj.rep(k):
            sj.join(j, k)
    a, b, c = heapq.nlargest(3, sj.sizes())
    return a * b * c


def main():
    """Read points from the files in ``sys.argv`` (or stdin) and print the answer."""
    with fileinput.input() as lines:
        pts = parse_points(lines)
    print(solve(pts))


if __name__ == "__main__":
    main()