lib.rs (7838B)
use std::fmt;
use std::ops;

/// Integer type used for residues and for the modulus `N`.
pub type BaseInt = i64;

/// Double-width integer for intermediate results: the sum, difference or
/// product of two `BaseInt` values always fits, so nothing overflows before
/// the final reduction modulo `N`.
type WideInt = i128;

/// An element of the ring of integers modulo `N` (Z/NZ).
///
/// `value` is always the canonical representative in `0..N` (every
/// constructor and operator reduces through `canonical_rep`/`reduce_wide`),
/// so the derived `PartialEq`, `Eq` and `Hash` compare residues correctly.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zmod<const N: BaseInt> {
    value: BaseInt,
}

/// Returns the canonical representative of `x` modulo `N`, in `0..N`.
///
/// `rem_euclid` cannot overflow here (callers guarantee `N > 1`), unlike the
/// `(x % N + N) % N` formulation whose intermediate sum overflows for moduli
/// above roughly `BaseInt::MAX / 2`.
fn canonical_rep<const N: BaseInt>(x: BaseInt) -> BaseInt {
    x.rem_euclid(N)
}

/// Reduces a double-width intermediate value to its canonical representative
/// in `0..N`.
fn reduce_wide<const N: BaseInt>(x: WideInt) -> BaseInt {
    // The remainder lies in `0..N`, so narrowing back to `BaseInt` is lossless.
    x.rem_euclid(WideInt::from(N)) as BaseInt
}

/// Extended Euclidean algorithm.
///
/// For non-negative `a` and `b`, returns `(g, x, y)` where `g = gcd(a, b)`
/// and `a * x + b * y == g`.
fn extended_gcd(a: BaseInt, b: BaseInt) -> (BaseInt, BaseInt, BaseInt) {
    if b == 0 {
        (a, 1, 0)
    } else {
        let (g, x, y) = extended_gcd(b, a % b);
        (g, y, x - y * (a / b))
    }
}

impl<const N: BaseInt> Zmod<N> {
    /// Creates the residue class of `x` modulo `N`; `x` may be any integer,
    /// including negative values.
    ///
    /// # Panics
    ///
    /// Panics if `N <= 1`. The check runs in every build profile; because `N`
    /// is a compile-time constant it costs nothing for valid moduli.
    pub fn from(x: BaseInt) -> Zmod<N> {
        Self::assert_valid_modulus();
        Zmod { value: canonical_rep::<N>(x) }
    }

    /// Creates a residue from a double-width intermediate arithmetic result.
    fn from_wide(x: WideInt) -> Zmod<N> {
        Self::assert_valid_modulus();
        Zmod { value: reduce_wide::<N>(x) }
    }

    /// Panics unless `N` is a usable modulus (at least 2).
    fn assert_valid_modulus() {
        assert!(N > 1, "modulus must be greater than 1");
    }

    /// Returns the multiplicative inverse of `self`.
    ///
    /// # Errors
    ///
    /// Returns `Err(g)` with `g = gcd(value, N)` when `g != 1`, i.e. when
    /// `self` is not invertible modulo `N`.
    fn inverse(self) -> Result<Zmod<N>, BaseInt> {
        let (g, a, _) = extended_gcd(self.value, N);
        if g == 1 {
            Ok(Zmod::<N>::from(a))
        } else {
            Err(g)
        }
    }
}

/// Trait counterpart of the inherent `Zmod::from`, enabling `x.into()`.
impl<const N: BaseInt> From<BaseInt> for Zmod<N> {
    fn from(x: BaseInt) -> Zmod<N> {
        // Inherent associated functions take precedence over trait ones, so
        // this calls the inherent constructor rather than recursing.
        Zmod::<N>::from(x)
    }
}

impl<const N: BaseInt> fmt::Display for Zmod<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "({} mod {})", self.value, N)
    }
}

impl<const N: BaseInt> ops::Add for Zmod<N> {
    type Output = Zmod<N>;

    fn add(self, z: Zmod<N>) -> Zmod<N> {
        // Widen first: `value + value` overflows `i64` for N > i64::MAX / 2.
        Zmod::<N>::from_wide(WideInt::from(self.value) + WideInt::from(z.value))
    }
}

impl<const N: BaseInt> ops::Add<BaseInt> for Zmod<N> {
    type Output = Zmod<N>;

    fn add(self, z: BaseInt) -> Zmod<N> {
        // Reduce the raw integer first so arbitrary `i64` values are safe.
        self + Zmod::<N>::from(z)
    }
}

impl<const N: BaseInt> ops::Add<Zmod<N>> for BaseInt {
    type Output = Zmod<N>;

    fn add(self, z: Zmod<N>) -> Zmod<N> {
        Zmod::<N>::from(self) + z
    }
}

impl<const N: BaseInt> ops::AddAssign for Zmod<N> {
    fn add_assign(&mut self, z: Zmod<N>) {
        *self = *self + z;
    }
}

impl<const N: BaseInt> ops::AddAssign<BaseInt> for Zmod<N> {
    fn add_assign(&mut self, z: BaseInt) {
        *self = *self + z;
    }
}

impl<const N: BaseInt> ops::Sub for Zmod<N> {
    type Output = Zmod<N>;

    fn sub(self, z: Zmod<N>) -> Zmod<N> {
        // Both values lie in `0..N`, so the difference lies in `(-N, N)` and
        // cannot overflow.
        Zmod::<N>::from(self.value - z.value)
    }
}

impl<const N: BaseInt> ops::Sub<BaseInt> for Zmod<N> {
    type Output = Zmod<N>;

    fn sub(self, z: BaseInt) -> Zmod<N> {
        self - Zmod::<N>::from(z)
    }
}

impl<const N: BaseInt> ops::Sub<Zmod<N>> for BaseInt {
    type Output = Zmod<N>;

    fn sub(self, z: Zmod<N>) -> Zmod<N> {
        Zmod::<N>::from(self) - z
    }
}

impl<const N: BaseInt> ops::SubAssign for Zmod<N> {
    fn sub_assign(&mut self, z: Zmod<N>) {
        *self = *self - z;
    }
}

impl<const N: BaseInt> ops::SubAssign<BaseInt> for Zmod<N> {
    fn sub_assign(&mut self, z: BaseInt) {
        *self = *self - z;
    }
}

impl<const N: BaseInt> ops::Neg for Zmod<N> {
    type Output = Zmod<N>;

    fn neg(self) -> Zmod<N> {
        // `value` is in `0..N`, so negating it cannot overflow.
        Zmod::<N>::from(-self.value)
    }
}

impl<const N: BaseInt> ops::Mul for Zmod<N> {
    type Output = Zmod<N>;

    fn mul(self, z: Zmod<N>) -> Zmod<N> {
        // Widen first: the product of two residues overflows `i64` as soon
        // as N exceeds roughly 3 * 10^9.
        Zmod::<N>::from_wide(WideInt::from(self.value) * WideInt::from(z.value))
    }
}

impl<const N: BaseInt> ops::Mul<BaseInt> for Zmod<N> {
    type Output = Zmod<N>;

    fn mul(self, z: BaseInt) -> Zmod<N> {
        self * Zmod::<N>::from(z)
    }
}

impl<const N: BaseInt> ops::Mul<Zmod<N>> for BaseInt {
    type Output = Zmod<N>;

    fn mul(self, z: Zmod<N>) -> Zmod<N> {
        Zmod::<N>::from(self) * z
    }
}

impl<const N: BaseInt> ops::MulAssign for Zmod<N> {
    fn mul_assign(&mut self, z: Zmod<N>) {
        *self = *self * z;
    }
}

impl<const N: BaseInt> ops::MulAssign<BaseInt> for Zmod<N> {
    fn mul_assign(&mut self, z: BaseInt) {
        *self = *self * z;
    }
}

/// Division multiplies by the divisor's inverse; `Err(g)` carries
/// `g = gcd(divisor, N)` when the divisor is not invertible.
impl<const N: BaseInt> ops::Div for Zmod<N> {
    type Output = Result<Zmod<N>, BaseInt>;

    fn div(self, z: Zmod<N>) -> Result<Zmod<N>, BaseInt> {
        Ok(self * z.inverse()?)
    }
}

impl<const N: BaseInt> ops::Div<BaseInt> for Zmod<N> {
    type Output = Result<Zmod<N>, BaseInt>;

    fn div(self, z: BaseInt) -> Result<Zmod<N>, BaseInt> {
        self / Zmod::<N>::from(z)
    }
}

impl<const N: BaseInt> ops::Div<Zmod<N>> for BaseInt {
    type Output = Result<Zmod<N>, BaseInt>;

    fn div(self, z: Zmod<N>) -> Result<Zmod<N>, BaseInt> {
        Zmod::<N>::from(self) / z
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Largest representable modulus; exercises the overflow-prone paths.
    const M: BaseInt = BaseInt::MAX;

    #[test]
    fn fmt_simple() {
        let x = Zmod::<5>::from(3);
        assert_eq!(x.to_string(), "(3 mod 5)");
    }

    #[test]
    fn two_is_zero_mod_two() {
        assert_eq!(Zmod::<2>::from(2), Zmod::<2>::from(0));
    }

    #[test]
    fn negative_one_is_one_mod_two() {
        assert_eq!(Zmod::<2>::from(-1), Zmod::<2>::from(1));
    }

    #[test]
    #[should_panic]
    fn negative_modulus_panic() {
        let _ = Zmod::<-3>::from(0);
    }

    #[test]
    #[should_panic]
    fn modulus_one_panic() {
        let _ = Zmod::<1>::from(0);
    }

    #[test]
    fn add_zmod_zmod() {
        let one = Zmod::<2>::from(1);
        assert_eq!(one + one, Zmod::<2>::from(0));
    }

    #[test]
    fn add_zmod_num() {
        let one = Zmod::<4>::from(1);
        assert_eq!(one + 3, Zmod::<4>::from(0));
    }

    #[test]
    fn add_num_zmod() {
        assert_eq!(3 + Zmod::<9>::from(-4), Zmod::<9>::from(8));
    }

    #[test]
    fn add_assign_zmod_zmod() {
        let mut x = Zmod::<7>::from(-2);
        x += Zmod::<7>::from(25);
        assert_eq!(x, Zmod::<7>::from(2));
    }

    #[test]
    fn add_assign_zmod_num() {
        let mut x = Zmod::<3>::from(2);
        x += 2;
        assert_eq!(x, Zmod::<3>::from(1));
    }

    #[test]
    fn subtract_zmod_zmod() {
        let x = Zmod::<5>::from(2);
        let y = Zmod::<5>::from(-4);
        assert_eq!(x - y, Zmod::<5>::from(1));
    }

    #[test]
    fn subtract_zmod_num() {
        assert_eq!(Zmod::<3>::from(1) - 2, Zmod::<3>::from(-1));
    }

    #[test]
    fn subtract_num_zmod() {
        assert_eq!(2 - Zmod::<7>::from(5), Zmod::<7>::from(4));
    }

    #[test]
    fn subtract_assign_zmod_zmod() {
        let mut x = Zmod::<15>::from(32);
        x -= Zmod::<15>::from(12);
        assert_eq!(x, Zmod::<15>::from(20));
    }

    #[test]
    fn subtract_assign_zmod_num() {
        let mut x = Zmod::<17>::from(11);
        x -= 20;
        assert_eq!(x, Zmod::<17>::from(-9));
    }

    #[test]
    fn neg() {
        assert_eq!(-Zmod::<14>::from(18), Zmod::<14>::from(-18));
    }

    #[test]
    fn multiply_zmod_zmod() {
        let x = Zmod::<9>::from(6);
        let y = Zmod::<9>::from(-3);
        assert_eq!(x * y, Zmod::<9>::from(0));
    }

    #[test]
    fn multiply_zmod_num() {
        assert_eq!(Zmod::<5>::from(4) * 2, Zmod::<5>::from(3));
    }

    #[test]
    fn multiply_num_zmod() {
        assert_eq!(6 * Zmod::<7>::from(5), Zmod::<7>::from(30));
    }

    #[test]
    fn multiply_assign_zmod_zmod() {
        let mut x = Zmod::<15>::from(3);
        x *= Zmod::<15>::from(7);
        assert_eq!(x, Zmod::<15>::from(6));
    }

    #[test]
    fn multiply_assign_zmod_num() {
        let mut x = Zmod::<17>::from(11);
        x *= 4;
        assert_eq!(x, Zmod::<17>::from(10));
    }

    #[test]
    fn inverse_one() {
        let x = Zmod::<44>::from(1);
        assert_eq!(x.inverse(), Ok(x));
    }

    #[test]
    fn inverse_12_35() {
        let x = Zmod::<35>::from(12);
        assert_eq!(x.inverse(), Ok(Zmod::<35>::from(3)));
    }

    #[test]
    fn inverse_fail() {
        let x = Zmod::<9>::from(6);
        assert_eq!(x.inverse(), Err(3));
    }

    #[test]
    fn divide_success() {
        let x = Zmod::<35>::from(15);
        let d = Zmod::<35>::from(6);
        assert_eq!(x / d, Ok(Zmod::<35>::from(20)));
    }

    #[test]
    fn add_near_max_modulus_does_not_overflow() {
        let x = Zmod::<M>::from(M - 1);
        assert_eq!(x + x, Zmod::<M>::from(M - 2));
    }

    #[test]
    fn multiply_near_max_modulus_does_not_overflow() {
        let x = Zmod::<M>::from(M - 1);
        assert_eq!(x * x, Zmod::<M>::from(1));
    }

    #[test]
    fn divide_near_max_modulus() {
        let minus_one = Zmod::<M>::from(M - 1);
        assert_eq!(Zmod::<M>::from(1) / minus_one, Ok(minus_one));
    }

    #[test]
    fn extreme_base_int_operands() {
        let three = Zmod::<7>::from(3);
        assert_eq!(three + BaseInt::MAX, three);
        assert_eq!(three - BaseInt::MIN, Zmod::<7>::from(4));
        assert_eq!(BaseInt::MIN - three, Zmod::<7>::from(3));
        assert_eq!(three * BaseInt::MIN, Zmod::<7>::from(4));
        assert_eq!(BaseInt::MAX * three, Zmod::<7>::from(0));
    }

    #[test]
    fn from_trait_conversion() {
        let z: Zmod<5> = 7_i64.into();
        assert_eq!(z, Zmod::<5>::from(2));
    }
}