zmodn

A simple C++ library for integers modulo N
git clone https://git.tronto.net/zmodn
Download | Log | Files | Refs | README

bigint.h (6232B)


      1 #ifndef BIGUNSIGNED_H
      2 #define BIGUNSIGNED_H
      3 
#include <algorithm>
#include <compare>
#include <cstdint>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <utility>
      8 
      9 constexpr uint64_t abs64(int64_t);
     10 constexpr uint64_t pow10(uint64_t);
     11 
     12 // Big integer class for numbers of at most N decimal digits.
     13 // The number E is used to tune the size of each digit, mostly for
     14 // testing purposes.
     15 
     16 template<uint64_t N = 50, uint64_t E = 9>
     17 requires (E < 10)
     18 class BigInt {
     19 public:
     20 	// The member variables sign and digits are declared public so that
     21 	// BigInt becomes a structural type and can be used in templates.
     22 
     23 	static constexpr uint64_t M = pow10(E);
     24 	static constexpr uint64_t D = (N / E) + 1;
     25 
     26 	bool sign;
     27 	uint64_t digits[D];
     28 
     29 	constexpr BigInt() : sign{true} {
     30 		std::fill(digits, digits+D, 0);
     31 	}
     32 
     33 	constexpr BigInt(int64_t n) : sign{n >= 0} {
     34 		std::fill(digits, digits+D, 0);
     35 		digits[0] = abs64(n);
     36 		carryover();
     37 	}
     38 
     39 	constexpr BigInt(const std::string_view s) : sign{true} {
     40 		std::fill(digits, digits+D, 0);
     41 		if (s.size() == 0)
     42 			return;
     43 		for (int i = s.size()-1, j = 0; i >= 0; i--, j++) {
     44 			if (s[i] == '\'')
     45 				continue;
     46 			if (i == 0 && s[i] == '-') {
     47 				sign = false;
     48 				break;
     49 			}
     50 			digits[j/E] += (pow10(j % E))
     51 			    * static_cast<uint64_t>(s[i] - '0');
     52 		}
     53 	}
     54 
     55 	constexpr auto operator<=>(const BigInt& other) const {
     56 		if (sign != other.sign)
     57 			return sign <=> other.sign;
     58 
     59 		for (int i = D-1; i >= 0; i--)
     60 			if (digits[i] != other.digits[i])
     61 				return sign ?
     62 				    digits[i] <=> other.digits[i] :
     63 				    other.digits[i] <=> digits[i];
     64 
     65 		return 0 <=> 0;
     66 	}
     67 
     68 	constexpr bool operator==(const BigInt& other) const = default;
     69 
     70 	constexpr BigInt abs() const {
     71 		BigInt ret = *this;
     72 		ret.sign = true;
     73 		return ret;
     74 	}
     75 
     76 	constexpr BigInt operator-() const {
     77 		if (*this == 0)
     78 			return 0;
     79 		BigInt ret = *this;
     80 		ret.sign = !ret.sign;
     81 		return ret;
     82 	}
     83 
     84 	constexpr BigInt operator+(const BigInt& z) const {
     85 		if (sign && z.sign)
     86 			return positive_sum(*this, z);
     87 		else if (sign && !z.sign)
     88 			return positive_diff(*this, -z);
     89 		else if (!sign && z.sign)
     90 			return positive_diff(z, -*this);
     91 		else
     92 			return -positive_sum(-*this, -z);
     93 	}
     94 
     95 	constexpr BigInt operator-(const BigInt& z) const {
     96 		return *this + (-z);
     97 	}
     98 
     99 	constexpr BigInt operator*(const BigInt& z) const {
    100 		BigInt ret;
    101 		ret.sign = !(sign ^ z.sign);
    102 		for (int i = 0; i < D; i++)
    103 			for (int j = 0; i+j < D; j++)
    104 				ret.digits[i+j] += digits[i] * z.digits[j];
    105 		ret.carryover();
    106 		return ret;
    107 	}
    108 
    109 	constexpr BigInt operator/(const BigInt& z) const {
    110 		auto [q, r] = euclidean_division(*this, z);
    111 		return q;
    112 	}
    113 
    114 	constexpr BigInt operator%(const BigInt& z) const {
    115 		auto [q, r] = euclidean_division(*this, z);
    116 		return r;
    117 	}
    118 
    119 	constexpr BigInt operator+=(const BigInt& z) { return *this = *this + z; }
    120 	constexpr BigInt operator++() { return *this += 1; }
    121 	constexpr BigInt operator-=(const BigInt& z) { return *this = *this - z; }
    122 	constexpr BigInt operator--() { return *this -= 1; }
    123 	constexpr BigInt operator*=(const BigInt& z) { return *this = *this * z; }
    124 	constexpr BigInt operator/=(const BigInt& z) { return *this = *this / z; }
    125 	constexpr BigInt operator%=(const BigInt& z) { return *this = *this % z; }
    126 
    127 	static BigInt random(BigInt r) {
    128 		std::random_device rd;
    129 		std::default_random_engine rng(rd());
    130 		std::uniform_int_distribution<int> distribution(0, M-1);
    131 
    132 		BigInt ret;
    133 		for (uint64_t i = 0; i < D; i++)
    134 			ret.digits[i] = distribution(rng);
    135 
    136 		return ret % r;
    137 	}
    138 
    139 	friend std::ostream& operator<<(std::ostream& os, const BigInt<N, E>& z) {
    140 		if (z == 0) {
    141 			os << "0";
    142 			return os;
    143 		}
    144 
    145 		if (!z.sign)
    146 			os << "-";
    147 
    148 		int j;
    149 		for (j = z.D-1; z.digits[j] == 0; j--) ;
    150 		os << z.digits[j]; // Top digit is not padded
    151 
    152 		for (int i = j-1; i >= 0; i--) {
    153 			std::string num = std::to_string(z.digits[i]);
    154 			os << std::string(E - num.length(), '0') << num;
    155 		}
    156 		return os;
    157 	}
    158 
    159 private:
    160 	constexpr void carryover() {
    161 		for (int i = 1; i < D; i++) {
    162 			auto c = digits[i-1] / M;
    163 			digits[i-1] -= c * M;
    164 			digits[i] += c;
    165 		}
    166 	}
    167 
    168 	constexpr BigInt half() const {
    169 		BigInt ret;
    170 		uint64_t carry = 0;
    171 		for (int i = D-1; i >= 0; i--) {
    172 			ret.digits[i] += (digits[i] + M * carry) / 2;
    173 			carry = digits[i] % 2;
    174 		}
    175 		return ret;
    176 	}
    177 
    178 	static constexpr BigInt powM(uint64_t e) {
    179 		BigInt ret;
    180 		ret.digits[e] = 1;
    181 		return ret;
    182 	}
    183 
    184 	// Sum of non-negative integers
    185 	static constexpr BigInt positive_sum(const BigInt& x, const BigInt& y) {
    186 		BigInt ret;
    187 		for (int i = 0; i < D; i++)
    188 			ret.digits[i] = x.digits[i] + y.digits[i];
    189 		ret.carryover();
    190 		return ret;
    191 	}
    192 
    193 	// Difference of non-negative integers (result may be negative)
    194 	static constexpr BigInt positive_diff(const BigInt& x, const BigInt& y) {
    195 		if (y > x)
    196 			return -positive_diff(y, x);
    197 
    198 		BigInt ret;
    199 		uint64_t carry = 0;
    200 		for (int i = 0; i < D; i++) {
    201 			uint64_t oldcarry = carry;
    202 			if (x.digits[i] < y.digits[i] + oldcarry) {
    203 				ret.digits[i] = M;
    204 				carry = 1;
    205 			} else {
    206 				carry = 0;
    207 			}
    208 			ret.digits[i] += x.digits[i];
    209 			ret.digits[i] -= y.digits[i] + oldcarry;
    210 		}
    211 		ret.carryover();
    212 		return ret;
    213 	}
    214 
    215 	// Division with remainder, UB if y == 0
    216 	static constexpr std::pair<BigInt, BigInt>
    217 	euclidean_division(const BigInt& x, const BigInt& y) {
    218 		auto [q, r] = positive_div(x.abs(), y.abs());
    219 		if (x.sign && y.sign)
    220 			return std::pair(q, r);
    221 		else if (x.sign && !y.sign)
    222 			return r == 0 ? std::pair(-q, 0) : std::pair(-q-1, y+r);
    223 		else if (!x.sign && y.sign)
    224 			return r == 0 ? std::pair(-q, r) : std::pair(-q-1, y-r);
    225 		else
    226 			return std::pair(q, -r);
    227 	}
    228 
    229 	// Division with remainder of non-negative integers, UB if y == 0
    230 	// This method is inefficient (O(log(x/y)) BigInt multiplications)
    231 	static constexpr std::pair<BigInt, BigInt>
    232 	positive_div(const BigInt& x, const BigInt& y) {
    233 		BigInt q = 0;
    234 		BigInt r = x;
    235 
    236 		if (y > x)
    237 			return std::pair(q, r);
    238 
    239 		BigInt lb = 0;
    240 		BigInt ub = x;
    241 		while (true) {
    242 			BigInt q = (ub + lb).half();
    243 			BigInt r = x - y*q;
    244 
    245 			if (r < 0)
    246 				ub = q;
    247 			else if (r >= y)
    248 				lb = q+1;
    249 			else
    250 				return std::pair(q, r);
    251 		}
    252 	}
    253 };
    254 
    255 constexpr uint64_t abs64(int64_t x) {
    256 	return static_cast<uint64_t>(x > 0 ? x : -x);
    257 }
    258 
    259 constexpr uint64_t pow10(uint64_t e) {
    260 	if (e == 0)
    261 		return 1;
    262 	else
    263 		return 10 * pow10(e-1);
    264 }
    265 
    266 #endif