zmodn

A simple C++ library for integers modulo N
git clone https://git.tronto.net/zmodn
Download | Log | Files | Refs | README

test (8594B)


      1 #if 0
      2 
      3 cc=${CC:-g++}
      4 bin="$(mktemp)"
      5 ${cc} -x c++ -std=c++20 -o "$bin" -g -O0 "$(realpath $0)"
      6 echo "Running $bin"
      7 "$bin"
      8 
      9 exit 0
     10 #endif
     11 
#include "zmodn.h"
#include "bigint.h"

#include <concepts>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
     19 
     20 template<typename S, typename T>
     21 requires std::convertible_to<S, T> || std::convertible_to<T, S>
     22 void assert_equal(S actual, T expected) {
     23 	if (actual != expected) {
     24 		std::cout << "Error!" << std::endl;
     25 		std::cout << "Expected: " << expected << std::endl;
     26 		std::cout << "But got:  " << actual << std::endl;
     27 		exit(1);
     28 	}
     29 }
     30 
     31 class Test {
     32 public:
     33 	std::string name;
     34 	std::function<void()> f;
     35 } tests[] = {
     36 {
     37 	.name = "Constructor 2 mod 3",
     38 	.f = []() {
     39 		Zmod<3> two = Zmod<3>(2);
     40 		assert_equal(two.toint(), INT64_C(2));
     41 	}
     42 },
     43 {
     44 	.name = "Constructor -7 mod 3",
     45 	.f = []() {
     46 		Zmod<3> z = -7;
     47 		assert_equal(z, Zmod<3>(2));
     48 	}
     49 },
     50 {
     51 	.name = "1+1 mod 2",
     52 	.f = []() {
     53 		auto oneplusone = Zmod<2>(1) + Zmod<2>(1);
     54 		assert_equal(oneplusone, Zmod<2>(0));
     55 	}
     56 },
     57 {
     58 	.name = "2 -= 5 (mod 4)",
     59 	.f = []() {
     60 		Zmod<4> z = 2;
     61 		auto diff = (z -= 5);
     62 		assert_equal(z, Zmod<4>(1));
     63 		assert_equal(diff, Zmod<4>(1));
     64 	}
     65 },
     66 {
     67 	.name = "Inverse of 0 mod 2",
     68 	.f = []() {
     69 		Zmod<2> z = 0;
     70 		auto inv = z.inverse();
     71 		assert_equal(inv.has_value(), false);
     72 	}
     73 },
     74 {
     75 	.name = "Inverse of 1 mod 2",
     76 	.f = []() {
     77 		Zmod<2> z = 1;
     78 		auto inv = z.inverse();
     79 		assert_equal(inv.has_value(), true);
     80 		assert_equal(inv.value(), Zmod<2>(1));
     81 	}
     82 },
     83 {
     84 	.name = "Inverse of 5 mod 7",
     85 	.f = []() {
     86 		Zmod<7> z = 5;
     87 		auto inv = z.inverse();
     88 		assert_equal(inv.has_value(), true);
     89 		assert_equal(inv.value(), Zmod<7>(3));
     90 	}
     91 },
     92 {
     93 	.name = "Inverse of 4 mod 12",
     94 	.f = []() {
     95 		Zmod<12> z = 4;
     96 		auto inv = z.inverse();
     97 		assert_equal(inv.has_value(), false);
     98 	}
     99 },
    100 {
    101 	.name = "4 / 7 (mod 12)",
    102 	.f = []() {
    103 		Zmod<12> n = 4;
    104 		Zmod<12> d = 7;
    105 		auto inv = n / d;
    106 		assert_equal(inv.has_value(), true);
    107 		assert_equal(inv.value(), Zmod<12>(4));
    108 	}
    109 },
    110 {
    111 	.name = "4 /= 7 (mod 12)",
    112 	.f = []() {
    113 		Zmod<12> n = 4;
    114 		Zmod<12> d = 7;
    115 		auto inv = (n /= d);
    116 		assert_equal(inv.has_value(), true);
    117 		assert_equal(inv.value(), Zmod<12>(4));
    118 		assert_equal(n, Zmod<12>(4));
    119 	}
    120 },
    121 {
    122 	.name = "Multiplication overflow",
    123 	.f = []() {
    124 		Zmod<10> n = 8;
    125 		Zmod<10> m = 9;
    126 		auto prod = m * n;
    127 		assert_equal(prod.toint(), 2);
    128 	}
    129 },
    130 {
    131 	.name = "Multiplication and assignment overflow",
    132 	.f = []() {
    133 		Zmod<10> n = 8;
    134 		Zmod<10> m = 9;
    135 		n *= m;
    136 		assert_equal(n.toint(), 2);
    137 	}
    138 },
    139 {
    140 	.name = "2^12 mod 154",
    141 	.f = []() {
    142 		Zmod<154> b = 2;
    143 		auto p = b ^ 12;
    144 		assert_equal(p, 92);
    145 	}
    146 },
    147 {
    148 	.name = "BigInt constructor zero",
    149 	.f = []() {
    150 		BigInt x;
    151 		BigInt y(0);
    152 
    153 		assert_equal(x, y);
    154 	}
    155 },
    156 {
    157 	.name = "BigInt constructor one digit",
    158 	.f = []() {
    159 		BigInt x(12345);
    160 		BigInt y("12345");
    161 
    162 		assert_equal(x, y);
    163 	}
    164 },
    165 {
    166 	.name = "BigInt constructor many small digits",
    167 	.f = []() {
    168 		BigInt<20, 2> x(123456789);
    169 		BigInt<20, 2> y("123456789");
    170 
    171 		assert_equal(x, y);
    172 	}
    173 },
    174 {
    175 	.name = "BigInt constructor negative many small digits",
    176 	.f = []() {
    177 		BigInt<20, 2> x(-123456789);
    178 		BigInt<20, 2> y("-123456789");
    179 
    180 		assert_equal(x, y);
    181 	}
    182 },
    183 {
    184 	.name = "BigInt operator==",
    185 	.f = []() {
    186 		BigInt<20, 2> x(123456789);
    187 		BigInt<20, 2> y("123456789");
    188 		BigInt<20, 2> z("12456789");
    189 
    190 		bool eq = (x == y);
    191 		bool diff = (x == z);
    192 
    193 		assert_equal(eq, true);
    194 		assert_equal(diff, false);
    195 	},
    196 },
    197 {
    198 	.name = "BigInt operator== negative",
    199 	.f = []() {
    200 		BigInt<20, 2> x("-123456789");
    201 		BigInt<20, 2> z("123456789");
    202 
    203 		bool diff = (x == z);
    204 
    205 		assert_equal(diff, false);
    206 	},
    207 },
    208 {
    209 	.name = "BigInt operator!= true",
    210 	.f = []() {
    211 		BigInt<20, 2> x(12345678);
    212 		BigInt<20, 2> y("123456789");
    213 		BigInt<20, 2> z("123456789");
    214 
    215 		bool diff = (x != y);
    216 		bool eq = (y != z);
    217 
    218 		assert_equal(diff, true);
    219 		assert_equal(eq, false);
    220 	},
    221 },
    222 {
    223 	.name = "BigInt operator< and operator>",
    224 	.f = []() {
    225 		BigInt<20, 2> x(7891);
    226 		BigInt<20, 2> y(7881);
    227 	
    228 		bool t = (y < x);
    229 		bool f = (x < y);
    230 
    231 		assert_equal(t, true);
    232 		assert_equal(f, false);
    233 	}
    234 },
    235 {
    236 	.name = "BigInt operator< both negative",
    237 	.f = []() {
    238 		BigInt<20, 2> x(-7891);
    239 		BigInt<20, 2> y(-7881);
    240 	
    241 		bool cmp = (x < y);
    242 
    243 		assert_equal(cmp, true);
    244 	}
    245 },
    246 {
    247 	.name = "BigInt operator< different sign",
    248 	.f = []() {
    249 		BigInt<20, 2> x(-7);
    250 		BigInt<20, 2> y(7);
    251 	
    252 		bool cmp = (x < y);
    253 
    254 		assert_equal(cmp, true);
    255 	}
    256 },
    257 {
    258 	.name = "BigInt abs",
    259 	.f = []() {
    260 		BigInt<20, 2> x(-1234567);
    261 		BigInt<20, 2> y(7654321);
    262 
    263 		assert_equal(x.abs(), BigInt<20, 2>(1234567));
    264 		assert_equal(y.abs(), y);
    265 	}
    266 },
    267 {
    268 	.name = "BigInt opposite",
    269 	.f = []() {
    270 		BigInt<20, 2> x(-1234567);
    271 		BigInt<20, 2> y(7654321);
    272 
    273 		assert_equal(-x, BigInt<20, 2>(1234567));
    274 		assert_equal(-y, BigInt<20, 2>(-7654321));
    275 	}
    276 },
    277 {
    278 	.name = "BigInt -0 == 0",
    279 	.f = []() {
    280 		BigInt z(0);
    281 
    282 		assert_equal(-z, z);
    283 	}
    284 },
    285 {
    286 	.name = "BigInt sum",
    287 	.f = []() {
    288 		BigInt<20, 2> x("987608548588589");
    289 		BigInt<20, 2> y("6793564545455289");
    290 		BigInt<20, 2> z("7781173094043878");
    291 
    292 		assert_equal(x+y, z);
    293 	}
    294 },
    295 {
    296 	.name = "BigInt sum both negative",
    297 	.f = []() {
    298 		BigInt<20, 2> x("-987608548588589");
    299 		BigInt<20, 2> y("-6793564545455289");
    300 		BigInt<20, 2> z("-7781173094043878");
    301 
    302 		assert_equal(x+y, z);
    303 	}
    304 },
    305 {
    306 	.name = "BigInt sum negative + positive, result positive",
    307 	.f = []() {
    308 		BigInt<20, 2> x("-987608548588589");
    309 		BigInt<20, 2> y("6793564545455289");
    310 		BigInt<20, 2> z("5805955996866700");
    311 
    312 		assert_equal(x+y, z);
    313 	}
    314 },
    315 {
    316 	.name = "BigInt sum positive + negative, result negative",
    317 	.f = []() {
    318 		BigInt<20, 2> x("987608548588589");
    319 		BigInt<20, 2> y("-6793564545455289");
    320 		BigInt<20, 2> z("-5805955996866700");
    321 
    322 		assert_equal(x+y, z);
    323 	}
    324 },
    325 {
    326 	.name = "BigInt difference",
    327 	.f = []() {
    328 		BigInt<20, 2> x("2342442323434134");
    329 		BigInt<20, 2> y("2524342523342342");
    330 		BigInt<20, 2> z("-181900199908208");
    331 
    332 		assert_equal(x-y, z);
    333 	}
    334 },
    335 {
    336 	.name = "BigInt product",
    337 	.f = []() {
    338 		BigInt<100, 3> x("134142345244134");
    339 		BigInt<100, 3> y("-56543047058245");
    340 		BigInt<100, 3> z("-7584816939642416135042584830");
    341 
    342 		assert_equal(x*y, z);
    343 	}
    344 },
    345 {
    346 	.name = "BigInt operator+=",
    347 	.f = []() {
    348 		BigInt<20, 2> x("987608548588589");
    349 		BigInt<20, 2> y("6793564545455289");
    350 		BigInt<20, 2> z("7781173094043878");
    351 
    352 		x += y;
    353 
    354 		assert_equal(x, z);
    355 	}
    356 },
    357 {
    358 	.name = "BigInt 14 / 3",
    359 	.f = []() {
    360 		BigInt x(14);
    361 		BigInt y(3);
    362 
    363 		assert_equal(x / y, 4);
    364 	}
    365 },
    366 {
    367 	.name = "BigInt 14 % 3",
    368 	.f = []() {
    369 		BigInt x(14);
    370 		BigInt y(3);
    371 
    372 		assert_equal(x % y, 2);
    373 	}
    374 },
    375 {
    376 	.name = "BigInt 14 / -3",
    377 	.f = []() {
    378 		BigInt x(14);
    379 		BigInt y(-3);
    380 
    381 		assert_equal(x / y, -5);
    382 	}
    383 },
    384 {
    385 	.name = "BigInt 14 % -3",
    386 	.f = []() {
    387 		BigInt x(14);
    388 		BigInt y(-3);
    389 
    390 		assert_equal(x % y, -1);
    391 	}
    392 },
    393 {
    394 	.name = "BigInt -14 / 3",
    395 	.f = []() {
    396 		BigInt x(-14);
    397 		BigInt y(3);
    398 
    399 		assert_equal(x / y, -5);
    400 	}
    401 },
    402 {
    403 	.name = "BigInt -14 % 3",
    404 	.f = []() {
    405 		BigInt x(-14);
    406 		BigInt y(3);
    407 
    408 		assert_equal(x % y, 1);
    409 	}
    410 },
    411 {
    412 	.name = "BigInt -14 / -3",
    413 	.f = []() {
    414 		BigInt x(-14);
    415 		BigInt y(-3);
    416 
    417 		assert_equal(x / y, 4);
    418 	}
    419 },
    420 {
    421 	.name = "BigInt -14 % -3",
    422 	.f = []() {
    423 		BigInt x(-14);
    424 		BigInt y(-3);
    425 
    426 		assert_equal(x % y, -2);
    427 	}
    428 },
    429 {
    430 	.name = "BigInt division large numbers, quotient = 0",
    431 	.f = []() {
    432 		BigInt<50, 3> x("4534435234134244242");
    433 		BigInt<50, 3> y("7832478748237487343");
    434 
    435 		assert_equal(x / y, 0);
    436 	}
    437 },
    438 {
    439 	.name = "BigInt division large numbers",
    440 	.f = []() {
    441 		BigInt<50, 3> x("12344534435234134244242");
    442 		BigInt<50, 3> y("7832478748237487343");
    443 		BigInt<50, 3> z(1576);
    444 
    445 		assert_equal(x / y, z);
    446 	}
    447 },
    448 {
    449 	.name = "BigInt modulo large numbers",
    450 	.f = []() {
    451 		BigInt<50, 3> x("12344534435234134244242");
    452 		BigInt<50, 3> y("7832478748237487343");
    453 		BigInt<50, 3> z("547928011854191674");
    454 
    455 		assert_equal(x % y, z);
    456 	}
    457 },
    458 {
    459 	.name = "Zmod with BigInt constructor",
    460 	.f = []() {
    461 		constexpr BigInt<50, 3> N("78923471");
    462 		constexpr BigInt<50, 3> x("145452451");
    463 		Zmod<N> xmodN(x);
    464 
    465 		assert_equal(xmodN.toint(), x % N);
    466 	}
    467 },
    468 {
    469 	.name = "Zmod with BigInt big inverse",
    470 	.f = []() {
    471 		constexpr BigInt<50, 3> N("7520824651249795349285");
    472 		constexpr BigInt<50, 3> x("234589234599896924596");
    473 		constexpr BigInt<50, 3> expected("5901181270843786267351");
    474 		Zmod<N> xmodN(x);
    475 
    476 		auto inv = xmodN.inverse();
    477 
    478 		assert_equal(inv.has_value(), true);
    479 		assert_equal(inv.value().toint(), expected);
    480 	}
    481 },
    482 /*
    483 {
    484 	.name = "This does not compile",
    485 	.f = []() {
    486 		constexpr double N = 1.2;
    487 		Zmod<N> x;
    488 	}
    489 }
    490 */
    491 };
    492 
    493 int main() {
    494 	for (auto t : tests) {
    495 		std::cout << t.name << ": ";
    496 		t.f();
    497 		std::cout << "OK" << std::endl;
    498 	}
    499 	std::cout << "All tests passed" << std::endl;
    500 	return 0;
    501 }