commit bb6a7fd447db8ebafe9df80097821cb0105907f7
parent 5d69f923116a2b251a445c0cb9e6f23551ad686a
Author: Sebastiano Tronto <sebastiano@tronto.net>
Date: Mon, 8 Dec 2025 08:02:19 +0100
Refactor
Diffstat:
2 files changed, 37 insertions(+), 30 deletions(-)
diff --git a/2025/08/a.py b/2025/08/a.py
@@ -2,31 +2,36 @@ import fileinput
N = 1000 # Change to 10 for test case
+class SetJoin:
+ def __init__(self, n):
+ self._parent = [i for i in range(n)]
+
+ def rep(self, i):
+ return i if self._parent[i] == i else self.rep(self._parent[i])
+
+ def join(self, i, j):
+ self._parent[self.rep(i)] = self.rep(j)
+
+ def sizes(self):
+ sizes = [0] * len(self._parent)
+ for i in range(len(self._parent)):
+ sizes[self.rep(i)] += 1
+ return sizes
+
with fileinput.input() as lines:
pts = [tuple(int(x) for x in line[:-1].split(',')) for line in lines]
-r = range(len(pts))
-
def dist(p, q):
return (p[0]-q[0])**2 + (p[1]-q[1])**2 + (p[2]-q[2])**2
+r = range(len(pts))
d = sorted([(dist(pts[i], pts[j]), i, j) for i in r for j in r if j > i])
-rep = [i for i in r]
-
-def findrep(i):
- return i if rep[i] == i else findrep(rep[i])
-
-def joinrep(i, j):
- rep[findrep(i)] = findrep(j)
-
+sj = SetJoin(len(pts))
for i in range(N):
j, k = d[i][1], d[i][2]
- if findrep(j) != findrep(k):
- joinrep(j, k)
-
-sizes = [[0, i] for i in r]
-for i in r:
- sizes[findrep(i)][0] += 1
-sizes.sort()
-print(sizes[-1][0] * sizes[-2][0] * sizes[-3][0])
+ if sj.rep(j) != sj.rep(k):
+ sj.join(j, k)
+
+s = sorted(sj.sizes())
+print(s[-1] * s[-2] * s[-3])
diff --git a/2025/08/b.py b/2025/08/b.py
@@ -1,26 +1,28 @@
import fileinput
+class SetJoin:
+ def __init__(self, n):
+ self._parent = [i for i in range(n)]
+
+ def rep(self, i):
+ return i if self._parent[i] == i else self.rep(self._parent[i])
+
+ def join(self, i, j):
+ self._parent[self.rep(i)] = self.rep(j)
+
with fileinput.input() as lines:
pts = [tuple(int(x) for x in line[:-1].split(',')) for line in lines]
-r = range(len(pts))
-
def dist(p, q):
return (p[0]-q[0])**2 + (p[1]-q[1])**2 + (p[2]-q[2])**2
+r = range(len(pts))
d = sorted([(dist(pts[i], pts[j]), i, j) for i in r for j in r if j > i])
-rep = [i for i in r]
-
-def findrep(i):
- return i if rep[i] == i else findrep(rep[i])
-
-def joinrep(i, j):
- rep[findrep(i)] = findrep(j)
-
+sj = SetJoin(len(pts))
for _, j, k in d:
- if findrep(j) != findrep(k):
- joinrep(j, k)
+ if sj.rep(j) != sj.rep(k):
+ sj.join(j, k)
sol = pts[j][0] * pts[k][0]
print(sol)