zmodn

A simple C++ library for integers modulo N
git clone https://git.tronto.net/zmodn
Download | Log | Files | Refs | README

commit 6eaa7b33aa8690f5ba4cee0897d2d05c71c27c20
Author: Sebastiano Tronto <sebastiano@tronto.net>
Date:   Sat, 21 Dec 2024 12:30:41 +0100

Initial commit

Diffstat:
AREADME.md | 10++++++++++
Atest | 128+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Azmodn.h | 71+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 209 insertions(+), 0 deletions(-)

diff --git a/README.md b/README.md @@ -0,0 +1,10 @@ +# ZmodN - A simple library for modular arithmetic + +Usage: + +1. `#include "zmodn.h"` in your project +2. enjoy + +# Development + +Run `chmod +x test` and then `./test` to run tests. diff --git a/test b/test @@ -0,0 +1,128 @@ +#if 0 + +cc=${CC:-g++} +bin="$(mktemp)" +${cc} -x c++ -std=c++20 -o "$bin" "$(realpath $0)" +"$bin" + +exit 0 +#endif + +#include "zmodn.h" +#include <concepts> +#include <functional> +#include <iostream> +#include <optional> + +template<typename S, typename T> +requires std::convertible_to<S, T> || std::convertible_to<T, S> +void assert_equal(S actual, T expected) { + if (actual != expected) { + std::cout << "Error!" << std::endl; + std::cout << "Expected: " << expected << std::endl; + std::cout << "But got: " << actual << std::endl; + exit(1); + } +} + +class Test { +public: + std::string name; + std::function<void()> f; +} tests[] = { +{ + .name = "Constructor 2 mod 3", + .f = []() { + Zmod<3> two = Zmod<3>(2); + assert_equal(two.toint64(), INT64_C(2)); + } +}, +{ + .name = "Constructor -7 mod 3", + .f = []() { + Zmod<3> z = -7; + assert_equal(z, Zmod<3>(2)); + } +}, +{ + .name = "1+1 mod 2", + .f = []() { + auto oneplusone = Zmod<2>(1) + Zmod<2>(1); + assert_equal(oneplusone, Zmod<2>(0)); + } +}, +{ + .name = "2 -= 5 (mod 4)", + .f = []() { + Zmod<4> z = 2; + auto diff = (z -= 5); + assert_equal(z, Zmod<4>(1)); + assert_equal(diff, Zmod<4>(1)); + } +}, +{ + .name = "Inverse of 0 mod 2", + .f = []() { + Zmod<2> z = 0; + auto inv = z.inverse(); + assert_equal(inv.has_value(), false); + } +}, +{ + .name = "Inverse of 1 mod 2", + .f = []() { + Zmod<2> z = 1; + auto inv = z.inverse(); + assert_equal(inv.has_value(), true); + assert_equal(inv.value(), Zmod<2>(1)); + } +}, +{ + .name = "Inverse of 5 mod 7", + .f = []() { + Zmod<7> z = 5; + auto inv = z.inverse(); + assert_equal(inv.has_value(), true); + assert_equal(inv.value(), Zmod<7>(3)); + } +}, +{ + .name = "Inverse of 4 mod 12", + .f = []() { + Zmod<12> z = 4; + auto inv = z.inverse(); + assert_equal(inv.has_value(), false); + } +}, +{ + .name = "4 / 7 (mod 12)", + .f = []() { + Zmod<12> n = 4; + Zmod<12> d = 7; + auto inv = n / d; + assert_equal(inv.has_value(), true); + assert_equal(inv.value(), Zmod<12>(4)); + } +}, +{ + .name = "4 /= 7 (mod 12)", + .f = []() { + Zmod<12> n = 4; + Zmod<12> d = 7; + auto inv = (n /= d); + assert_equal(inv.has_value(), true); + assert_equal(inv.value(), Zmod<12>(4)); + assert_equal(n, Zmod<12>(4)); + } +}, +}; + +int main() { + for (auto t : tests) { + std::cout << t.name << ": "; + t.f(); + std::cout << "OK" << std::endl; + } + std::cout << "All tests passed" << std::endl; + return 0; +} diff --git a/zmodn.h b/zmodn.h @@ -0,0 +1,71 @@ +#ifndef ZMODN_H +#define ZMODN_H + +#include <cstdint> +#include <iostream> +#include <optional> +#include <tuple> + +std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t, int64_t); + +template<int64_t N> requires(N > 1) +class Zmod { +public: + Zmod(int64_t z) : int64{(z%N + N) % N} {} + int64_t toint64() const { return int64; } + + Zmod operator+(const Zmod& z) const { return int64 + z.int64; } + Zmod operator-(const Zmod& z) const { return int64 - z.int64; } + Zmod operator*(const Zmod& z) const { return int64 * z.int64; } + Zmod operator+=(const Zmod& z) { return int64 += z.int64; } + Zmod operator-=(const Zmod& z) { return int64 -= z.int64; } + Zmod operator*=(const Zmod& z) { return int64 *= z.int64; } + + bool operator==(const Zmod& z) const { return int64 == z.int64; } + bool operator!=(const Zmod& z) const { return int64 != z.int64; } + + std::optional<Zmod> inverse() const { + auto [g, a, _] = extended_gcd(int64, N); + return g == 1 ? Zmod(a) : std::optional<Zmod>{}; + } + + std::optional<Zmod> operator/(const Zmod& d) const { + auto i = d.inverse(); + return i ? (*this) * i.value() : i; + } + + std::optional<Zmod> operator/=(const Zmod& d) { + auto q = *this / d; + return q ? (*this = q.value()) : q; + } + + friend std::ostream& operator<<(std::ostream& os, const Zmod<N>& z) { + return os << "(" << z.int64 << " mod " << N << ")"; + } +private: + int64_t int64; +}; + +void swapdiv(int64_t& oldx, int64_t& x, int64_t q) { + int64_t tmp = x; + x = oldx - q*tmp; + oldx = tmp; +} + +std::tuple<int64_t, int64_t, int64_t> extended_gcd(int64_t a, int64_t b) { + int64_t oldr = a; + int64_t r = b; + int64_t olds = 1; + int64_t s = 0; + int64_t oldt = 0; + int64_t t = 1; + while (r != 0) { + auto q = oldr / r; + swapdiv(oldr, r, q); + swapdiv(olds, s, q); + swapdiv(oldt, t, q); + } + return {oldr, olds, oldt}; +} + +#endif