b-debug.py (3000B)
"""Minimum-total non-negative integer solution of a 0/1 linear system.

Each input line looks like ``[....] (i,j,...) (k,...) ... {t0,t1,...}``.
The ``{...}`` list is the right-hand side ``b``; every ``(...)`` group that
appears after the first ``]`` is one variable (a column of ``A``) with
coefficient 1 in each row index it lists.  For every line the script finds
the non-negative integer vector ``x`` with ``A x = b`` that minimises
``sum(x)``, then prints the per-line minima and their total.
"""
import fileinput
import sys
from itertools import product
from math import gcd

# Sentinel returned by solve_system_min_sum when no valid solution exists.
inf = 999999999


def readline(l):
    """Parse one input line into the integer system ``(A, b, c)``.

    ``b`` is the comma-separated integer list inside ``{...}``.  Every
    ``(...)`` group after the first ``]`` becomes one column of ``A``: row
    ``i`` holds a 1 in that column iff ``i`` is listed in the group, else 0.
    ``c[k]`` is an upper bound for variable ``k``: the smallest ``b`` value
    among the rows that variable touches (valid because every coefficient
    is 0/1 and every variable is non-negative).

    Returns:
        ``(A, b, c)`` -- ``len(b)`` rows with one entry per group, the
        right-hand side, and the per-variable upper bounds.

    Raises:
        ValueError: if ``{``, ``}`` or ``]`` is missing, a group is not
            closed, or a number fails to parse.
        IndexError: if a group lists a row index outside ``b``.
    """
    b = [int(x) for x in l[l.index('{') + 1:l.index('}')].split(',')]
    A = [[] for _ in b]
    c = []

    # Scan the "(...)" groups that follow the "[...]" prefix.
    end = l.index(']')
    while (begin := l.find('(', end)) != -1:
        end = l.index(')', begin)
        members = {int(x) for x in l[begin + 1:end].split(',')}
        c.append(min(b[j] for j in members))
        for i, row in enumerate(A):
            row.append(1 if i in members else 0)

    return A, b, c


def printabc(A, b, c):
    """Debug-print the augmented matrix ``[A | b]`` and the bounds ``c``."""
    print("--")
    for row, rhs in zip(A, b):
        print(row, [rhs])
    print(f"Parameter bounds: {c}")
    print("--")


def swaprow(A, b, i, j):
    """Swap rows ``i`` and ``j`` of ``A`` (in place) together with ``b``."""
    if i != j:
        A[i], A[j] = A[j], A[i]
        b[i], b[j] = b[j], b[i]


def swapcol(A, c, i, j):
    """Swap columns ``i`` and ``j`` of ``A`` (in place) with their bounds in ``c``."""
    if i != j:
        for row in A:
            row[i], row[j] = row[j], row[i]
        c[i], c[j] = c[j], c[i]


def reducerow(A, b, i, j):
    """Eliminate column ``i`` from row ``j`` using pivot row ``i``.

    Row ``j`` becomes ``(-A[j][i] * row_i + A[i][i] * row_j) / g`` with
    ``g = gcd(A[i][i], A[j][i])``, which zeroes ``A[j][i]`` while keeping
    every entry an integer (no fractions).  ``b[j]`` gets the same
    combination.  Row ``j`` may be rescaled or change sign, which keeps the
    equation equivalent.  Does nothing if the pivot ``A[i][i]`` is zero.
    """
    x = A[i][i]
    if x == 0:
        return
    y = -A[j][i]
    d = gcd(x, y)  # divides both multipliers, so the divisions are exact
    A[j] = [(y * p + x * q) // d for p, q in zip(A[i], A[j])]
    b[j] = (y * b[i] + x * b[j]) // d


def reduce(A, b, c):
    """Row-reduce the integer system ``A x = b`` without fractions.

    Performs Gauss-Jordan elimination using integer row combinations.
    Columns may be permuted (their bounds in ``c`` are permuted alongside).
    On return, row ``i`` has a positive pivot in column ``i``.  Its other
    non-zero entries lie only in columns ``>= len(A)``, the free variables.
    Each row is divided by the gcd of its coefficients and right-hand side.

    The input lists are mutated during elimination; use the returned values,
    since all-zero rows are dropped into new lists.

    Returns:
        ``(A, b, c)`` -- reduced rows, matching right-hand side, and the
        (possibly permuted) bounds.

    Exits the process with status 1, after printing the system, if a row
    reduces to ``0 = b`` with ``b != 0``.
    """
    nrows, ncols = len(A), len(c)

    # Forward elimination.
    for i in range(min(nrows, ncols)):
        # Bring the first column (at or right of i) that is non-zero in some
        # row >= i into position i.
        pivot_col = next(
            (k for k in range(i, ncols) if any(A[j][k] for j in range(i, nrows))),
            None,
        )
        if pivot_col is None:
            break  # every remaining row is entirely zero
        swapcol(A, c, i, pivot_col)

        # Move the first row with a non-zero entry in column i up to row i.
        pivot_row = next(j for j in range(i, nrows) if A[j][i] != 0)
        swaprow(A, b, i, pivot_row)

        for j in range(i + 1, nrows):
            reducerow(A, b, i, j)

    # Drop all-zero rows; one with a non-zero right-hand side is 0 = b != 0.
    nonzero = [any(row) for row in A]
    if any(rhs != 0 for rhs, nz in zip(b, nonzero) if not nz):
        printabc(A, b, c)
        print("Unsolvable!")
        sys.exit(1)
    A = [row for row, nz in zip(A, nonzero) if nz]
    b = [rhs for rhs, nz in zip(b, nonzero) if nz]

    # Back substitution: clear each pivot column from the rows above it, so
    # row i only involves pivot variable i and the free variables.
    for i in reversed(range(len(A))):
        for j in range(i):
            reducerow(A, b, i, j)

    # Shrink coefficients.  The right-hand side must take part in the gcd:
    # dividing the coefficients alone and flooring b would change the
    # equation (e.g. 2x + 2y = 3 would become x + y = 1).
    for i, row in enumerate(A):
        d = gcd(*row, b[i])
        if row[i] < 0:
            d = -d  # make the pivot positive
        A[i] = [a // d for a in row]
        b[i] //= d

    return A, b, c


def _iter_params(nparam, c):
    """Lazily yield each assignment of the last ``nparam`` variables within ``c``."""
    if not 0 <= nparam <= len(c):
        raise ValueError(f"nparam must be in 0..{len(c)}, got {nparam}")
    bounds = c[len(c) - nparam:]
    return product(*(range(bound + 1) for bound in bounds))


def paramcomb(nparam, c):
    """Return every assignment of the last ``nparam`` variables.

    Each assignment is a list ``[v_0, ..., v_{nparam-1}]`` with
    ``0 <= v_m <= c[len(c) - nparam + m]``, listed in lexicographic order.
    ``nparam == 0`` yields ``[[]]``.
    """
    return [list(p) for p in _iter_params(nparam, c)]


def solve_system_min_sum(A, b, c):
    """Return the minimum of ``sum(x)`` over non-negative integer solutions.

    Expects the output of ``reduce``: row ``i`` of ``A`` has its pivot in
    column ``i`` and other non-zero entries only in the trailing free
    columns.  Every assignment of the free variables within the bounds in
    ``c`` is tried exhaustively.  The cost is the product of ``c[k] + 1``
    over the free columns.  For each assignment, every pivot variable is
    solved exactly; the assignment is rejected if any pivot value is
    negative or non-integral.

    Returns ``inf`` if no assignment yields a valid solution.
    """
    nrows = len(A)
    nfree = len(c) - nrows  # columns nrows.. are the free variables
    best = inf
    for free in _iter_params(nfree, c):
        total = sum(free)
        for i, row in enumerate(A):
            rest = b[i] - sum(v * a for v, a in zip(free, row[nrows:]))
            pivot, rem = divmod(rest, row[i])
            if pivot < 0 or rem != 0:
                break  # pivot variable would be negative or fractional
            total += pivot
        else:
            best = min(best, total)
    return best


def main():
    """Solve every line of the files named on the command line (or stdin)."""
    sols = []
    with fileinput.input() as lines:
        for k, line in enumerate(lines, start=1):
            text = line.rstrip("\r\n")
            if not text.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            print(f"doing line {k}: {text}")
            A, b, c = readline(line)
            A, b, c = reduce(A, b, c)
            sols.append(solve_system_min_sum(A, b, c))
    print(sols)
    print(sum(sols))


if __name__ == "__main__":
    main()