commit 97c94018402c95f2ed1f59645f4b1cbdd255b978
parent 6cfb7b4d8e6b0838bfcde290e5fac18f46f18fb4
Author: Sebastiano Tronto <sebastiano@tronto.net>
Date: Tue, 21 Jan 2025 08:03:15 +0100
Final examples for templates
Diffstat:
5 files changed, 383 insertions(+), 2 deletions(-)
diff --git a/templates/bigint.h b/templates/bigint.h
@@ -0,0 +1,255 @@
+// Taken from https://git.tronto.net/zmodn
+
+#ifndef BIGUNSIGNED_H
+#define BIGUNSIGNED_H
+
+#include <cstdint>
+#include <iostream>
+#include <string_view>
+
+constexpr uint64_t abs64(int64_t);
+constexpr uint64_t pow10(uint64_t);
+
+// Big integer class for numbers of at most N decimal digits.
+// The number E is used to tune the size of each digit, mostly for
+// testing purposes.
+
+template<uint64_t N = 50, uint64_t E = 9>
+requires (E < 10)
+class BigInt {
+public:
+ // The member variables sign and digits are declared public so that
+ // BigInt becomes a structural type and can be used in templates.
+
+ static constexpr uint64_t M = pow10(E);
+ static constexpr uint64_t D = (N / E) + 1;
+
+ bool sign;
+ uint64_t digits[D];
+
+ constexpr BigInt() : sign{true} {
+ std::fill(digits, digits+D, 0);
+ }
+
+ constexpr BigInt(int64_t n) : sign{n >= 0} {
+ std::fill(digits, digits+D, 0);
+ digits[0] = abs64(n);
+ carryover();
+ }
+
+ constexpr BigInt(const std::string_view s) : sign{true} {
+ std::fill(digits, digits+D, 0);
+ if (s.size() == 0)
+ return;
+ for (int i = s.size()-1, j = 0; i >= 0; i--, j++) {
+ if (s[i] == '\'')
+ continue;
+ if (i == 0 && s[i] == '-') {
+ sign = false;
+ break;
+ }
+ digits[j/E] += (pow10(j % E))
+ * static_cast<uint64_t>(s[i] - '0');
+ }
+ }
+
+ constexpr auto operator<=>(const BigInt& other) const {
+ if (sign != other.sign)
+ return sign <=> other.sign;
+
+ for (int i = D-1; i >= 0; i--)
+ if (digits[i] != other.digits[i])
+ return sign ?
+ digits[i] <=> other.digits[i] :
+ other.digits[i] <=> digits[i];
+
+ return 0 <=> 0;
+ }
+
+ constexpr bool operator==(const BigInt& other) const = default;
+
+ constexpr BigInt abs() const {
+ BigInt ret = *this;
+ ret.sign = true;
+ return ret;
+ }
+
+ constexpr BigInt operator-() const {
+ if (*this == 0)
+ return 0;
+ BigInt ret = *this;
+ ret.sign = !ret.sign;
+ return ret;
+ }
+
+ constexpr BigInt operator+(const BigInt& z) const {
+ if (sign && z.sign)
+ return positive_sum(*this, z);
+ else if (sign && !z.sign)
+ return positive_diff(*this, -z);
+ else if (!sign && z.sign)
+ return positive_diff(z, -*this);
+ else
+ return -positive_sum(-*this, -z);
+ }
+
+ constexpr BigInt operator-(const BigInt& z) const {
+ return *this + (-z);
+ }
+
+ constexpr BigInt operator*(const BigInt& z) const {
+ BigInt ret;
+ ret.sign = !(sign ^ z.sign);
+ for (int i = 0; i < D; i++)
+ for (int j = 0; i+j < D; j++)
+ ret.digits[i+j] += digits[i] * z.digits[j];
+ ret.carryover();
+ return ret;
+ }
+
+ constexpr BigInt operator/(const BigInt& z) const {
+ auto [q, r] = euclidean_division(*this, z);
+ return q;
+ }
+
+ constexpr BigInt operator%(const BigInt& z) const {
+ auto [q, r] = euclidean_division(*this, z);
+ return r;
+ }
+
+ constexpr BigInt operator+=(const BigInt& z) { return *this = *this + z; }
+ constexpr BigInt operator++() { return *this += 1; }
+ constexpr BigInt operator-=(const BigInt& z) { return *this = *this - z; }
+ constexpr BigInt operator--() { return *this -= 1; }
+ constexpr BigInt operator*=(const BigInt& z) { return *this = *this * z; }
+ constexpr BigInt operator/=(const BigInt& z) { return *this = *this / z; }
+ constexpr BigInt operator%=(const BigInt& z) { return *this = *this % z; }
+
+ friend std::ostream& operator<<(std::ostream& os, const BigInt<N, E>& z) {
+ if (z == 0) {
+ os << "0";
+ return os;
+ }
+
+ if (!z.sign)
+ os << "-";
+
+ int j;
+ for (j = z.D-1; z.digits[j] == 0; j--) ;
+ os << z.digits[j]; // Top digit is not padded
+
+ for (int i = j-1; i >= 0; i--) {
+ std::string num = std::to_string(z.digits[i]);
+ os << std::string(E - num.length(), '0') << num;
+ }
+ return os;
+ }
+
+private:
+ constexpr void carryover() {
+ for (int i = 1; i < D; i++) {
+ auto c = digits[i-1] / M;
+ digits[i-1] -= c * M;
+ digits[i] += c;
+ }
+ }
+
+ constexpr BigInt half() const {
+ BigInt ret;
+ uint64_t carry = 0;
+ for (int i = D-1; i >= 0; i--) {
+ ret.digits[i] += (digits[i] + M * carry) / 2;
+ carry = digits[i] % 2;
+ }
+ return ret;
+ }
+
+ static constexpr BigInt powM(uint64_t e) {
+ BigInt ret;
+ ret.digits[e] = 1;
+ return ret;
+ }
+
+ // Sum of non-negative integers
+ static constexpr BigInt positive_sum(const BigInt& x, const BigInt& y) {
+ BigInt ret;
+ for (int i = 0; i < D; i++)
+ ret.digits[i] = x.digits[i] + y.digits[i];
+ ret.carryover();
+ return ret;
+ }
+
+ // Difference of non-negative integers (result may be negative)
+ static constexpr BigInt positive_diff(const BigInt& x, const BigInt& y) {
+ if (y > x)
+ return -positive_diff(y, x);
+
+ BigInt ret;
+ uint64_t carry = 0;
+ for (int i = 0; i < D; i++) {
+ uint64_t oldcarry = carry;
+ if (x.digits[i] < y.digits[i] + oldcarry) {
+ ret.digits[i] = M;
+ carry = 1;
+ } else {
+ carry = 0;
+ }
+ ret.digits[i] += x.digits[i];
+ ret.digits[i] -= y.digits[i] + oldcarry;
+ }
+ ret.carryover();
+ return ret;
+ }
+
+ // Division with remainder, UB if y == 0
+ static constexpr std::pair<BigInt, BigInt>
+ euclidean_division(const BigInt& x, const BigInt& y) {
+ auto [q, r] = positive_div(x.abs(), y.abs());
+ if (x.sign && y.sign)
+ return std::pair(q, r);
+ else if (x.sign && !y.sign)
+ return r == 0 ? std::pair(-q, 0) : std::pair(-q-1, y+r);
+ else if (!x.sign && y.sign)
+ return r == 0 ? std::pair(-q, r) : std::pair(-q-1, y-r);
+ else
+ return std::pair(q, -r);
+ }
+
+ // Division with remainder of non-negative integers, UB if y == 0
+ // This method is inefficient (O(log(x/y)) BigInt multiplications)
+ static constexpr std::pair<BigInt, BigInt>
+ positive_div(const BigInt& x, const BigInt& y) {
+ BigInt q = 0;
+ BigInt r = x;
+
+ if (y > x)
+ return std::pair(q, r);
+
+ BigInt lb = 0;
+ BigInt ub = x;
+ while (true) {
+ BigInt q = (ub + lb).half();
+ BigInt r = x - y*q;
+
+ if (r < 0)
+ ub = q;
+ else if (r >= y)
+ lb = q+1;
+ else
+ return std::pair(q, r);
+ }
+ }
+};
+
+constexpr uint64_t abs64(int64_t x) {
+ return static_cast<uint64_t>(x > 0 ? x : -x);
+}
+
+constexpr uint64_t pow10(uint64_t e) {
+ if (e == 0)
+ return 1;
+ else
+ return 10 * pow10(e-1);
+}
+
+#endif
diff --git a/templates/zmodn-1.cpp b/templates/zmodn-1.cpp
@@ -14,7 +14,6 @@ public:
int value;
Zmod(int z) : value{(z%N + N) % N} {}
- int toint() const { return value; }
Zmod operator+(const Zmod& z) const { return value + z.value; }
Zmod operator-(const Zmod& z) const { return value - z.value; }
diff --git a/templates/zmodn-2.cpp b/templates/zmodn-2.cpp
@@ -15,7 +15,6 @@ public:
int value;
Zmod(int z) : value{(z%N + N) % N} {}
- int toint() const { return value; }
Zmod operator+(const Zmod& z) const { return value + z.value; }
Zmod operator-(const Zmod& z) const { return value - z.value; }
diff --git a/templates/zmodn-3.cpp b/templates/zmodn-3.cpp
@@ -0,0 +1,57 @@
+#include "bigint.h"
+
+#include <iostream>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+template<typename T>
+std::tuple<T, T, T> extended_gcd(T a, T b) {
+ if (b == 0) return {a, 1, 0};
+ auto [g, x, y] = extended_gcd(b, a%b);
+ return {g, y, x - y*(a/b)};
+}
+
+template<auto N>
+requires (N > 1)
+class Zmod {
+public:
+ decltype(N) value;
+
+ Zmod(decltype(N) z) : value{(z%N + N) % N} {}
+
+ Zmod operator+(const Zmod& z) const { return value + z.value; }
+ Zmod operator-(const Zmod& z) const { return value - z.value; }
+ Zmod operator*(const Zmod& z) const { return value * z.value; }
+
+ std::optional<Zmod> inverse() const {
+ auto [g, a, _] = extended_gcd(value, N);
+ return g == 1 ? Zmod(a) : std::optional<Zmod>{};
+ }
+
+ std::optional<Zmod> operator/(const Zmod& d) const {
+ auto i = d.inverse();
+ return i ? (*this) * i.value() : i;
+ }
+
+ std::optional<Zmod> operator/=(const Zmod& d) {
+ auto q = *this / d;
+ return q ? (*this = q.value()) : q;
+ }
+};
+
+int main() {
+ constexpr BigInt N("1000000000000000000000000000000");
+ Zmod<N> x(BigInt("123456781234567812345678"));
+ Zmod<N> y(BigInt("987654321987654321"));
+
+ std::cout << x.value << " * "
+ << y.value << " (mod " << N << ") = "
+ << (x * y).value << std::endl;
+
+ // The following gives a compile error on the first % operation
+ // constexpr double M = 3.14;
+ // Zmod<M> z(4);
+
+ return 0;
+}
diff --git a/templates/zmodn-4.cpp b/templates/zmodn-4.cpp
@@ -0,0 +1,71 @@
+#include "bigint.h"
+
+#include <iostream>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+template<typename T>
+concept Integer = requires(T a, T b, int i) {
+ {T(i)};
+
+ {a + b} -> std::same_as<T>;
+ {a - b} -> std::same_as<T>;
+ {a * b} -> std::same_as<T>;
+ {a / b} -> std::same_as<T>;
+ {a % b} -> std::same_as<T>;
+
+ {a == b} -> std::same_as<bool>;
+ {a != b} -> std::same_as<bool>;
+};
+
+template<Integer T>
+std::tuple<T, T, T> extended_gcd(T a, T b) {
+ if (b == 0) return {a, 1, 0};
+ auto [g, x, y] = extended_gcd(b, a%b);
+ return {g, y, x - y*(a/b)};
+}
+
+template<Integer auto N>
+requires (N > 1)
+class Zmod {
+public:
+ decltype(N) value;
+
+ Zmod(decltype(N) z) : value{(z%N + N) % N} {}
+
+ Zmod operator+(const Zmod& z) const { return value + z.value; }
+ Zmod operator-(const Zmod& z) const { return value - z.value; }
+ Zmod operator*(const Zmod& z) const { return value * z.value; }
+
+ std::optional<Zmod> inverse() const {
+ auto [g, a, _] = extended_gcd(value, N);
+ return g == 1 ? Zmod(a) : std::optional<Zmod>{};
+ }
+
+ std::optional<Zmod> operator/(const Zmod& d) const {
+ auto i = d.inverse();
+ return i ? (*this) * i.value() : i;
+ }
+
+ std::optional<Zmod> operator/=(const Zmod& d) {
+ auto q = *this / d;
+ return q ? (*this = q.value()) : q;
+ }
+};
+
+int main() {
+ constexpr BigInt N("1000000000000000000000000000000");
+ Zmod<N> x(BigInt("123456781234567812345678"));
+ Zmod<N> y(BigInt("987654321987654321"));
+
+ std::cout << x.value << " * "
+ << y.value << " (mod " << N << ") = "
+ << (x * y).value << std::endl;
+
+ // The following line gives an error when trying to specialize Zmod<M>
+ // constexpr double M = 3.14;
+ // Zmod<M> z(4);
+
+ return 0;
+}