zmodn-rs

A simple Rust library for integers modulo N
git clone https://git.tronto.net/zmodn-rs
Download | Log | Files | Refs | README

lib.rs (7838B)


      1 use std::fmt;
      2 use std::ops;
      3 
      4 pub type BaseInt = i64;
      5 
      6 // We assume canonical representative, can compare value for PartialEq
      7 #[derive(Copy, Clone, Debug, PartialEq)]
      8 pub struct Zmod<const N: BaseInt> {
      9     value: BaseInt
     10 }
     11 
     12 fn canonical_rep<const N: BaseInt>(x: BaseInt) -> BaseInt {
     13     return (x % N + N) % N;
     14 }
     15 
     16 fn extended_gcd(a: BaseInt, b: BaseInt) -> (BaseInt, BaseInt, BaseInt) {
     17     if b == 0 {
     18         return (a, 1, 0);
     19     }
     20     let (g, x, y) = extended_gcd(b, a%b);
     21     (g, y, x - y*(a/b))
     22 }
     23 
     24 impl<const N: BaseInt> Zmod<N> {
     25     pub fn from(x: BaseInt) -> Zmod<N> {
     26         #[cfg(debug_assertions)]
     27         assert!(N > 1, "modulus must be greater than 1");
     28 
     29         Zmod::<N> { value: canonical_rep::<N>(x) }
     30     }
     31 
     32     fn inverse(self) -> Result<Zmod<N>, BaseInt> {
     33         let (g, a, _) = extended_gcd(self.value, N);
     34         if g == 1 {
     35             Ok(Zmod::<N>::from(a))
     36         } else {
     37             Err(g)
     38         }
     39     }
     40 }
     41 
     42 impl<const N: BaseInt> fmt::Display for Zmod<N> {
     43     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
     44         write!(f, "({} mod {})", self.value, N)
     45     }
     46 }
     47 
     48 impl<const N: BaseInt> ops::Add for Zmod<N> {
     49     type Output = Zmod<N>;
     50 
     51     fn add(self, z: Zmod<N>) -> Zmod<N> {
     52         Zmod::<N>::from(self.value + z.value)
     53     }
     54 }
     55 
     56 impl<const N: BaseInt> ops::Add<BaseInt> for Zmod<N> {
     57     type Output = Zmod<N>;
     58 
     59     fn add(self, z: BaseInt) -> Zmod<N> {
     60         Zmod::<N>::from(self.value + z)
     61     }
     62 }
     63 
     64 impl<const N: BaseInt> ops::Add<Zmod<N>> for BaseInt {
     65     type Output = Zmod::<N>;
     66 
     67     fn add(self, z: Zmod::<N>) -> Zmod<N> {
     68         Zmod::<N>::from(self + z.value)
     69     }
     70 }
     71 
     72 impl<const N: BaseInt> ops::AddAssign for Zmod<N> {
     73     fn add_assign(&mut self, z: Zmod<N>) {
     74         self.value = canonical_rep::<N>(self.value + z.value);
     75     }
     76 }
     77 
     78 impl<const N: BaseInt> ops::AddAssign<BaseInt> for Zmod<N> {
     79     fn add_assign(&mut self, z: BaseInt) {
     80         self.value = canonical_rep::<N>(self.value + z);
     81     }
     82 }
     83 
     84 impl<const N: BaseInt> ops::Sub for Zmod<N> {
     85     type Output = Zmod<N>;
     86 
     87     fn sub(self, z: Zmod<N>) -> Zmod<N> {
     88         Zmod::<N>::from(self.value - z.value)
     89     }
     90 }
     91 
     92 impl<const N: BaseInt> ops::Sub<BaseInt> for Zmod<N> {
     93     type Output = Zmod<N>;
     94 
     95     fn sub(self, z: BaseInt) -> Zmod<N> {
     96         Zmod::<N>::from(self.value - z)
     97     }
     98 }
     99 
    100 impl<const N: BaseInt> ops::Sub<Zmod<N>> for BaseInt {
    101     type Output = Zmod<N>;
    102 
    103     fn sub(self, z: Zmod<N>) -> Zmod<N> {
    104         Zmod::<N>::from(self - z.value)
    105     }
    106 }
    107 
    108 impl<const N: BaseInt> ops::SubAssign for Zmod<N> {
    109     fn sub_assign(&mut self, z: Zmod<N>) {
    110         self.value = canonical_rep::<N>(self.value - z.value);
    111     }
    112 }
    113 
    114 impl<const N: BaseInt> ops::SubAssign<BaseInt> for Zmod<N> {
    115     fn sub_assign(&mut self, z: BaseInt) {
    116         self.value = canonical_rep::<N>(self.value - z);
    117     }
    118 }
    119 
    120 impl<const N: BaseInt> ops::Neg for Zmod<N> {
    121     type Output = Zmod<N>;
    122 
    123     fn neg(self) -> Zmod<N> {
    124         Zmod::<N>::from(-self.value)
    125     }
    126 }
    127 
    128 impl<const N: BaseInt> ops::Mul for Zmod<N> {
    129     type Output = Zmod<N>;
    130 
    131     fn mul(self, z: Zmod<N>) -> Zmod<N> {
    132         Zmod::<N>::from(self.value * z.value)
    133     }
    134 }
    135 
    136 impl<const N: BaseInt> ops::Mul<BaseInt> for Zmod<N> {
    137     type Output = Zmod<N>;
    138 
    139     fn mul(self, z: BaseInt) -> Zmod<N> {
    140         Zmod::<N>::from(self.value * z)
    141     }
    142 }
    143 
    144 impl<const N: BaseInt> ops::Mul<Zmod<N>> for BaseInt {
    145     type Output = Zmod<N>;
    146 
    147     fn mul(self, z: Zmod<N>) -> Zmod<N> {
    148         Zmod::<N>::from(self * z.value)
    149     }
    150 }
    151 
    152 impl<const N: BaseInt> ops::MulAssign for Zmod<N> {
    153     fn mul_assign(&mut self, z: Zmod<N>) {
    154         self.value = canonical_rep::<N>(self.value * z.value);
    155     }
    156 }
    157 
    158 impl<const N: BaseInt> ops::MulAssign<BaseInt> for Zmod<N> {
    159     fn mul_assign(&mut self, z: BaseInt) {
    160         self.value = canonical_rep::<N>(self.value * z);
    161     }
    162 }
    163 
    164 impl<const N: BaseInt> ops::Div for Zmod<N> {
    165     type Output = Result<Zmod<N>, BaseInt>;
    166 
    167     fn div(self, z: Zmod<N>) -> Result<Zmod<N>, BaseInt> {
    168         Ok(self * z.inverse()?)
    169     }
    170 }
    171 
    172 impl<const N: BaseInt> ops::Div<BaseInt> for Zmod<N> {
    173     type Output = Result<Zmod<N>, BaseInt>;
    174 
    175     fn div(self, z: BaseInt) -> Result<Zmod<N>, BaseInt> {
    176         self / Zmod::<N>::from(z)
    177     }
    178 }
    179 
    180 impl<const N: BaseInt> ops::Div<Zmod<N>> for BaseInt {
    181     type Output = Result<Zmod<N>, BaseInt>;
    182 
    183     fn div(self, z: Zmod<N>) -> Result<Zmod<N>, BaseInt> {
    184         Zmod::<N>::from(self) / z
    185     }
    186 }
    187 
    188 #[cfg(test)]
    189 mod tests {
    190     use super::*;
    191 
    192     #[test]
    193     fn fmt_simple() {
    194         let x = Zmod::<5>::from(3);
    195         assert_eq!(x.to_string(), "(3 mod 5)");
    196     }
    197 
    198     #[test]
    199     fn two_is_zero_mod_two() {
    200         assert_eq!(Zmod::<2>::from(2), Zmod::<2>::from(0));
    201     }
    202 
    203     #[test]
    204     fn negative_one_is_one_mod_two() {
    205         assert_eq!(Zmod::<2>::from(-1), Zmod::<2>::from(1));
    206     }
    207 
    208     #[test]
    209     #[should_panic]
    210     fn negative_modulus_panic() {
    211         let _ = Zmod::<-3>::from(0);
    212     }
    213 
    214     #[test]
    215     #[should_panic]
    216     fn modulus_one_panic() {
    217         let _ = Zmod::<1>::from(0);
    218     }
    219 
    220     #[test]
    221     fn add_zmod_zmod() {
    222         let one = Zmod::<2>::from(1);
    223         assert_eq!(one + one, Zmod::<2>::from(0));
    224     }
    225 
    226     #[test]
    227     fn add_zmod_num() {
    228         let one = Zmod::<4>::from(1);
    229         assert_eq!(one + 3, Zmod::<4>::from(0));
    230     }
    231 
    232     #[test]
    233     fn add_num_zmod() {
    234         assert_eq!(3 + Zmod::<9>::from(-4), Zmod::<9>::from(8));
    235     }
    236 
    237     #[test]
    238     fn add_assign_zmod_zmod() {
    239         let mut x = Zmod::<7>::from(-2);
    240         x += Zmod::<7>::from(25);
    241         assert_eq!(x, Zmod::<7>::from(2));
    242     }
    243 
    244     #[test]
    245     fn add_assign_zmod_num() {
    246         let mut x = Zmod::<3>::from(2);
    247         x += 2;
    248         assert_eq!(x, Zmod::<3>::from(1));
    249     }
    250 
    251     #[test]
    252     fn subtract_zmod_zmod() {
    253         let x = Zmod::<5>::from(2);
    254         let y = Zmod::<5>::from(-4);
    255         assert_eq!(x - y, Zmod::<5>::from(1));
    256     }
    257 
    258     #[test]
    259     fn subtract_zmod_num() {
    260         assert_eq!(Zmod::<3>::from(1) - 2, Zmod::<3>::from(-1));
    261     }
    262 
    263     #[test]
    264     fn subtract_num_zmod() {
    265         assert_eq!(2 - Zmod::<7>::from(5), Zmod::<7>::from(4));
    266     }
    267 
    268     #[test]
    269     fn subtract_assign_zmod_zmod() {
    270         let mut x = Zmod::<15>::from(32);
    271         x -= Zmod::<15>::from(12);
    272         assert_eq!(x, Zmod::<15>::from(20));
    273     }
    274 
    275     #[test]
    276     fn subtract_assign_zmod_num() {
    277         let mut x = Zmod::<17>::from(11);
    278         x -= 20;
    279         assert_eq!(x, Zmod::<17>::from(-9));
    280     }
    281 
    282     #[test]
    283     fn neg() {
    284         assert_eq!(-Zmod::<14>::from(18), Zmod::<14>::from(-18));
    285     }
    286 
    287     #[test]
    288     fn multiply_zmod_zmod() {
    289         let x = Zmod::<9>::from(6);
    290         let y = Zmod::<9>::from(-3);
    291         assert_eq!(x * y, Zmod::<9>::from(0));
    292     }
    293 
    294     #[test]
    295     fn multiply_zmod_num() {
    296         assert_eq!(Zmod::<5>::from(4) * 2, Zmod::<5>::from(3));
    297     }
    298 
    299     #[test]
    300     fn multiply_num_zmod() {
    301         assert_eq!(6 * Zmod::<7>::from(5), Zmod::<7>::from(30));
    302     }
    303 
    304     #[test]
    305     fn multiply_assign_zmod_zmod() {
    306         let mut x = Zmod::<15>::from(3);
    307         x *= Zmod::<15>::from(7);
    308         assert_eq!(x, Zmod::<15>::from(6));
    309     }
    310 
    311     #[test]
    312     fn multiply_assign_zmod_num() {
    313         let mut x = Zmod::<17>::from(11);
    314         x *= 4;
    315         assert_eq!(x, Zmod::<17>::from(10));
    316     }
    317 
    318     #[test]
    319     fn inverse_one() {
    320         let x = Zmod::<44>::from(1);
    321         assert_eq!(x.inverse(), Ok(x));
    322     }
    323 
    324     #[test]
    325     fn inverse_12_35() {
    326         let x = Zmod::<35>::from(12);
    327         assert_eq!(x.inverse(), Ok(Zmod::<35>::from(3)));
    328     }
    329 
    330     #[test]
    331     fn inverse_fail() {
    332         let x = Zmod::<9>::from(6);
    333         assert_eq!(x.inverse(), Err(3));
    334     }
    335 
    336     #[test]
    337     fn divide_success() {
    338         let x = Zmod::<35>::from(15);
    339         let d = Zmod::<35>::from(6);
    340         assert_eq!(x / d, Ok(Zmod::<35>::from(20)));
    341     }
    342 }