b.py (683B)
"""Find the last edge needed to connect a set of 3-D integer points.

Input (files named on the command line, or stdin) holds one point per
line as comma-separated integers.  All point pairs are visited in order
of increasing squared distance and merged with a disjoint-set structure,
in the manner of Kruskal's algorithm.  The script prints the product of
the first coordinates of the two points joined by the last merge.
"""

import fileinput
from itertools import combinations


class SetJoin:
    """Disjoint-set (union-find) structure over the integers ``0..n-1``."""

    def __init__(self, n):
        """Create ``n`` singleton sets; each element starts as its own parent."""
        self._parent = list(range(n))

    def rep(self, i):
        """Return the representative of the set containing ``i``.

        Every element visited on the way to the root is re-pointed
        directly at the root (path compression).
        """
        # Walk iteratively rather than recursively: a chain of joins can
        # produce a parent chain as long as the input, which would exceed
        # the interpreter's recursion limit.
        root = i
        while self._parent[root] != root:
            root = self._parent[root]
        while self._parent[i] != root:
            self._parent[i], i = root, self._parent[i]
        return root

    def join(self, i, j):
        """Merge the set containing ``i`` into the set containing ``j``."""
        self._parent[self.rep(i)] = self.rep(j)


def parse_points(lines):
    """Parse an iterable of ``"x,y,z"`` text lines into a list of int tuples.

    Surrounding whitespace (including a missing or present trailing
    newline) is ignored, and blank lines are skipped.

    Raises:
        ValueError: if a non-blank line contains a non-integer field.
    """
    points = []
    for line in lines:
        # strip() instead of line[:-1]: slicing would cut off the final
        # digit of a last line that has no trailing newline.
        text = line.strip()
        if text:
            points.append(tuple(int(x) for x in text.split(',')))
    return points


def dist(p, q):
    """Return the squared Euclidean distance between 3-D points ``p`` and ``q``.

    Only the first three components of each point are used.  The square
    root is omitted; squared distances sort in the same order.
    """
    return (p[0]-q[0])**2 + (p[1]-q[1])**2 + (p[2]-q[2])**2


def last_join_product(pts):
    """Return ``x1 * x2`` for the pair of points merged last.

    Pairs are processed by increasing squared distance, ties broken by
    point index.  A pair is merged only if its points are still in
    different sets; the first coordinates of the last merged pair are
    multiplied together.

    Raises:
        ValueError: if fewer than two points are given, since no merge
            can happen.
    """
    if len(pts) < 2:
        raise ValueError("need at least two points to join")

    edges = sorted(
        (dist(pts[i], pts[j]), i, j)
        for i, j in combinations(range(len(pts)), 2)
    )

    sj = SetJoin(len(pts))
    remaining = len(pts) - 1  # merges left until a single set remains
    sol = None
    for _, j, k in edges:
        if sj.rep(j) != sj.rep(k):
            sj.join(j, k)
            sol = pts[j][0] * pts[k][0]
            remaining -= 1
            if remaining == 0:
                # Everything is connected; later edges cannot merge anything.
                break
    return sol


def main():
    """Read points from the command-line files or stdin and print the answer."""
    with fileinput.input() as lines:
        pts = parse_points(lines)
    print(last_join_product(pts))


if __name__ == "__main__":
    main()