zmodn

A simple C++ library for integers modulo N
git clone https://git.tronto.net/zmodn
Download | Log | Files | Refs | README

test (9023B)


      1 #if 0
      2 
      3 cc=${CC:-g++}
      4 bin="$(mktemp)"
      5 ${cc} -x c++ -std=c++20 -o "$bin" -g -O0 "$(realpath $0)"
      6 echo "Running $bin"
      7 "$bin"
      8 
      9 exit 0
     10 #endif
     11 
#include "zmodn.h"
#include "bigint.h"

#if defined(__clang__) && __clang_major__ >= 16
#include "bitint_wrapper.h"
#endif

#include <concepts>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
     23 
     24 template<typename S, typename T>
     25 requires std::convertible_to<S, T> || std::convertible_to<T, S>
     26 void assert_equal(S actual, T expected) {
     27 	if (actual != expected) {
     28 		std::cout << "Error!" << std::endl;
     29 		std::cout << "Expected: " << expected << std::endl;
     30 		std::cout << "But got:  " << actual << std::endl;
     31 		exit(1);
     32 	}
     33 }
     34 
     35 class Test {
     36 public:
     37 	std::string name;
     38 	std::function<void()> f;
     39 } tests[] = {
     40 {
     41 	.name = "Constructor 2 mod 3",
     42 	.f = []() {
     43 		Zmod<3> two = Zmod<3>(2);
     44 		assert_equal(two.toint(), INT64_C(2));
     45 	}
     46 },
     47 {
     48 	.name = "Constructor -7 mod 3",
     49 	.f = []() {
     50 		Zmod<3> z = -7;
     51 		assert_equal(z, Zmod<3>(2));
     52 	}
     53 },
     54 {
     55 	.name = "1+1 mod 2",
     56 	.f = []() {
     57 		auto oneplusone = Zmod<2>(1) + Zmod<2>(1);
     58 		assert_equal(oneplusone, Zmod<2>(0));
     59 	}
     60 },
     61 {
     62 	.name = "2 -= 5 (mod 4)",
     63 	.f = []() {
     64 		Zmod<4> z = 2;
     65 		auto diff = (z -= 5);
     66 		assert_equal(z, Zmod<4>(1));
     67 		assert_equal(diff, Zmod<4>(1));
     68 	}
     69 },
     70 {
     71 	.name = "Inverse of 0 mod 2",
     72 	.f = []() {
     73 		Zmod<2> z = 0;
     74 		auto inv = z.inverse();
     75 		assert_equal(inv.has_value(), false);
     76 	}
     77 },
     78 {
     79 	.name = "Inverse of 1 mod 2",
     80 	.f = []() {
     81 		Zmod<2> z = 1;
     82 		auto inv = z.inverse();
     83 		assert_equal(inv.has_value(), true);
     84 		assert_equal(inv.value(), Zmod<2>(1));
     85 	}
     86 },
     87 {
     88 	.name = "Inverse of 5 mod 7",
     89 	.f = []() {
     90 		Zmod<7> z = 5;
     91 		auto inv = z.inverse();
     92 		assert_equal(inv.has_value(), true);
     93 		assert_equal(inv.value(), Zmod<7>(3));
     94 	}
     95 },
     96 {
     97 	.name = "Inverse of 4 mod 12",
     98 	.f = []() {
     99 		Zmod<12> z = 4;
    100 		auto inv = z.inverse();
    101 		assert_equal(inv.has_value(), false);
    102 	}
    103 },
    104 {
    105 	.name = "4 / 7 (mod 12)",
    106 	.f = []() {
    107 		Zmod<12> n = 4;
    108 		Zmod<12> d = 7;
    109 		auto inv = n / d;
    110 		assert_equal(inv.has_value(), true);
    111 		assert_equal(inv.value(), Zmod<12>(4));
    112 	}
    113 },
    114 {
    115 	.name = "4 /= 7 (mod 12)",
    116 	.f = []() {
    117 		Zmod<12> n = 4;
    118 		Zmod<12> d = 7;
    119 		auto inv = (n /= d);
    120 		assert_equal(inv.has_value(), true);
    121 		assert_equal(inv.value(), Zmod<12>(4));
    122 		assert_equal(n, Zmod<12>(4));
    123 	}
    124 },
    125 {
    126 	.name = "Multiplication overflow",
    127 	.f = []() {
    128 		Zmod<10> n = 8;
    129 		Zmod<10> m = 9;
    130 		auto prod = m * n;
    131 		assert_equal(prod.toint(), 2);
    132 	}
    133 },
    134 {
    135 	.name = "Multiplication and assignment overflow",
    136 	.f = []() {
    137 		Zmod<10> n = 8;
    138 		Zmod<10> m = 9;
    139 		n *= m;
    140 		assert_equal(n.toint(), 2);
    141 	}
    142 },
    143 {
    144 	.name = "2^12 mod 154",
    145 	.f = []() {
    146 		Zmod<154> b = 2;
    147 		auto p = b ^ 12;
    148 		assert_equal(p, 92);
    149 	}
    150 },
    151 {
    152 	.name = "BigInt constructor zero",
    153 	.f = []() {
    154 		BigInt x;
    155 		BigInt y(0);
    156 
    157 		assert_equal(x, y);
    158 	}
    159 },
    160 {
    161 	.name = "BigInt constructor one digit",
    162 	.f = []() {
    163 		BigInt x(12345);
    164 		BigInt y("12345");
    165 
    166 		assert_equal(x, y);
    167 	}
    168 },
    169 {
    170 	.name = "BigInt constructor many small digits",
    171 	.f = []() {
    172 		BigInt<20, 2> x(123456789);
    173 		BigInt<20, 2> y("123456789");
    174 
    175 		assert_equal(x, y);
    176 	}
    177 },
    178 {
    179 	.name = "BigInt constructor negative many small digits",
    180 	.f = []() {
    181 		BigInt<20, 2> x(-123456789);
    182 		BigInt<20, 2> y("-123456789");
    183 
    184 		assert_equal(x, y);
    185 	}
    186 },
    187 {
    188 	.name = "BigInt operator==",
    189 	.f = []() {
    190 		BigInt<20, 2> x(123456789);
    191 		BigInt<20, 2> y("123456789");
    192 		BigInt<20, 2> z("12456789");
    193 
    194 		bool eq = (x == y);
    195 		bool diff = (x == z);
    196 
    197 		assert_equal(eq, true);
    198 		assert_equal(diff, false);
    199 	},
    200 },
    201 {
    202 	.name = "BigInt operator== negative",
    203 	.f = []() {
    204 		BigInt<20, 2> x("-123456789");
    205 		BigInt<20, 2> z("123456789");
    206 
    207 		bool diff = (x == z);
    208 
    209 		assert_equal(diff, false);
    210 	},
    211 },
    212 {
    213 	.name = "BigInt operator!= true",
    214 	.f = []() {
    215 		BigInt<20, 2> x(12345678);
    216 		BigInt<20, 2> y("123456789");
    217 		BigInt<20, 2> z("123456789");
    218 
    219 		bool diff = (x != y);
    220 		bool eq = (y != z);
    221 
    222 		assert_equal(diff, true);
    223 		assert_equal(eq, false);
    224 	},
    225 },
    226 {
    227 	.name = "BigInt operator< and operator>",
    228 	.f = []() {
    229 		BigInt<20, 2> x(7891);
    230 		BigInt<20, 2> y(7881);
    231 	
    232 		bool t = (y < x);
    233 		bool f = (x < y);
    234 
    235 		assert_equal(t, true);
    236 		assert_equal(f, false);
    237 	}
    238 },
    239 {
    240 	.name = "BigInt operator< both negative",
    241 	.f = []() {
    242 		BigInt<20, 2> x(-7891);
    243 		BigInt<20, 2> y(-7881);
    244 	
    245 		bool cmp = (x < y);
    246 
    247 		assert_equal(cmp, true);
    248 	}
    249 },
    250 {
    251 	.name = "BigInt operator< different sign",
    252 	.f = []() {
    253 		BigInt<20, 2> x(-7);
    254 		BigInt<20, 2> y(7);
    255 	
    256 		bool cmp = (x < y);
    257 
    258 		assert_equal(cmp, true);
    259 	}
    260 },
    261 {
    262 	.name = "BigInt abs",
    263 	.f = []() {
    264 		BigInt<20, 2> x(-1234567);
    265 		BigInt<20, 2> y(7654321);
    266 
    267 		assert_equal(x.abs(), BigInt<20, 2>(1234567));
    268 		assert_equal(y.abs(), y);
    269 	}
    270 },
    271 {
    272 	.name = "BigInt opposite",
    273 	.f = []() {
    274 		BigInt<20, 2> x(-1234567);
    275 		BigInt<20, 2> y(7654321);
    276 
    277 		assert_equal(-x, BigInt<20, 2>(1234567));
    278 		assert_equal(-y, BigInt<20, 2>(-7654321));
    279 	}
    280 },
    281 {
    282 	.name = "BigInt -0 == 0",
    283 	.f = []() {
    284 		BigInt z(0);
    285 
    286 		assert_equal(-z, z);
    287 	}
    288 },
    289 {
    290 	.name = "BigInt sum",
    291 	.f = []() {
    292 		BigInt<20, 2> x("987608548588589");
    293 		BigInt<20, 2> y("6793564545455289");
    294 		BigInt<20, 2> z("7781173094043878");
    295 
    296 		assert_equal(x+y, z);
    297 	}
    298 },
    299 {
    300 	.name = "BigInt sum both negative",
    301 	.f = []() {
    302 		BigInt<20, 2> x("-987608548588589");
    303 		BigInt<20, 2> y("-6793564545455289");
    304 		BigInt<20, 2> z("-7781173094043878");
    305 
    306 		assert_equal(x+y, z);
    307 	}
    308 },
    309 {
    310 	.name = "BigInt sum negative + positive, result positive",
    311 	.f = []() {
    312 		BigInt<20, 2> x("-987608548588589");
    313 		BigInt<20, 2> y("6793564545455289");
    314 		BigInt<20, 2> z("5805955996866700");
    315 
    316 		assert_equal(x+y, z);
    317 	}
    318 },
    319 {
    320 	.name = "BigInt sum positive + negative, result negative",
    321 	.f = []() {
    322 		BigInt<20, 2> x("987608548588589");
    323 		BigInt<20, 2> y("-6793564545455289");
    324 		BigInt<20, 2> z("-5805955996866700");
    325 
    326 		assert_equal(x+y, z);
    327 	}
    328 },
    329 {
    330 	.name = "BigInt difference",
    331 	.f = []() {
    332 		BigInt<20, 2> x("2342442323434134");
    333 		BigInt<20, 2> y("2524342523342342");
    334 		BigInt<20, 2> z("-181900199908208");
    335 
    336 		assert_equal(x-y, z);
    337 	}
    338 },
    339 {
    340 	.name = "BigInt product",
    341 	.f = []() {
    342 		BigInt<100, 3> x("134142345244134");
    343 		BigInt<100, 3> y("-56543047058245");
    344 		BigInt<100, 3> z("-7584816939642416135042584830");
    345 
    346 		assert_equal(x*y, z);
    347 	}
    348 },
    349 {
    350 	.name = "BigInt operator+=",
    351 	.f = []() {
    352 		BigInt<20, 2> x("987608548588589");
    353 		BigInt<20, 2> y("6793564545455289");
    354 		BigInt<20, 2> z("7781173094043878");
    355 
    356 		x += y;
    357 
    358 		assert_equal(x, z);
    359 	}
    360 },
    361 {
    362 	.name = "BigInt 14 / 3",
    363 	.f = []() {
    364 		BigInt x(14);
    365 		BigInt y(3);
    366 
    367 		assert_equal(x / y, 4);
    368 	}
    369 },
    370 {
    371 	.name = "BigInt 14 % 3",
    372 	.f = []() {
    373 		BigInt x(14);
    374 		BigInt y(3);
    375 
    376 		assert_equal(x % y, 2);
    377 	}
    378 },
    379 {
    380 	.name = "BigInt 14 / -3",
    381 	.f = []() {
    382 		BigInt x(14);
    383 		BigInt y(-3);
    384 
    385 		assert_equal(x / y, -5);
    386 	}
    387 },
    388 {
    389 	.name = "BigInt 14 % -3",
    390 	.f = []() {
    391 		BigInt x(14);
    392 		BigInt y(-3);
    393 
    394 		assert_equal(x % y, -1);
    395 	}
    396 },
    397 {
    398 	.name = "BigInt -14 / 3",
    399 	.f = []() {
    400 		BigInt x(-14);
    401 		BigInt y(3);
    402 
    403 		assert_equal(x / y, -5);
    404 	}
    405 },
    406 {
    407 	.name = "BigInt -14 % 3",
    408 	.f = []() {
    409 		BigInt x(-14);
    410 		BigInt y(3);
    411 
    412 		assert_equal(x % y, 1);
    413 	}
    414 },
    415 {
    416 	.name = "BigInt -14 / -3",
    417 	.f = []() {
    418 		BigInt x(-14);
    419 		BigInt y(-3);
    420 
    421 		assert_equal(x / y, 4);
    422 	}
    423 },
    424 {
    425 	.name = "BigInt -14 % -3",
    426 	.f = []() {
    427 		BigInt x(-14);
    428 		BigInt y(-3);
    429 
    430 		assert_equal(x % y, -2);
    431 	}
    432 },
    433 {
    434 	.name = "BigInt division large numbers, quotient = 0",
    435 	.f = []() {
    436 		BigInt<50, 3> x("4534435234134244242");
    437 		BigInt<50, 3> y("7832478748237487343");
    438 
    439 		assert_equal(x / y, 0);
    440 	}
    441 },
    442 {
    443 	.name = "BigInt division large numbers",
    444 	.f = []() {
    445 		BigInt<50, 3> x("12344534435234134244242");
    446 		BigInt<50, 3> y("7832478748237487343");
    447 		BigInt<50, 3> z(1576);
    448 
    449 		assert_equal(x / y, z);
    450 	}
    451 },
    452 {
    453 	.name = "BigInt modulo large numbers",
    454 	.f = []() {
    455 		BigInt<50, 3> x("12344534435234134244242");
    456 		BigInt<50, 3> y("7832478748237487343");
    457 		BigInt<50, 3> z("547928011854191674");
    458 
    459 		assert_equal(x % y, z);
    460 	}
    461 },
    462 {
    463 	.name = "Zmod with BigInt constructor",
    464 	.f = []() {
    465 		constexpr BigInt<50, 3> N("78923471");
    466 		constexpr BigInt<50, 3> x("145452451");
    467 		Zmod<N> xmodN(x);
    468 
    469 		assert_equal(xmodN.toint(), x % N);
    470 	}
    471 },
    472 {
    473 	.name = "Zmod with BigInt big inverse",
    474 	.f = []() {
    475 		constexpr BigInt<50, 3> N("7520824651249795349285");
    476 		constexpr BigInt<50, 3> x("234589234599896924596");
    477 		constexpr BigInt<50, 3> expected("5901181270843786267351");
    478 		Zmod<N> xmodN(x);
    479 
    480 		auto inv = xmodN.inverse();
    481 
    482 		assert_equal(inv.has_value(), true);
    483 		assert_equal(inv.value().toint(), expected);
    484 	}
    485 },
    486 #if defined(__clang__) && __clang_major__ >= 16
    487 {
    488 	.name = "Zmod with BitInt",
    489 	.f = []() {
    490 		constexpr BitInt N{1000000000000LL}; // 1e12
    491 		BitInt a{600000000000LL}; // 6e11
    492 		BitInt b{700000000000LL}; // 7e11
    493 		BitInt c{300000000000LL}; // 3e11
    494 
    495 		Zmod<N> a_modN(a), b_modN(b), c_modN(c);
    496 		assert_equal(a_modN + b_modN, c_modN);
    497 	}
    498 },
    499 #endif
    500 /*
    501 {
    502 	.name = "This does not compile",
    503 	.f = []() {
    504 		constexpr double N = 1.2;
    505 		Zmod<N> x;
    506 	}
    507 }
    508 */
    509 };
    510 
    511 int main() {
    512 	for (auto t : tests) {
    513 		std::cout << t.name << ": ";
    514 		t.f();
    515 		std::cout << "OK" << std::endl;
    516 	}
    517 	std::cout << "All tests passed" << std::endl;
    518 	return 0;
    519 }