blob: f4c2b5e6e6f240c9ea59b99e32301498571df9de [file] [log] [blame]
Damien George438c88d2014-02-22 19:25:23 +00001#include <stdint.h>
2#include <stdbool.h>
3#include <stdlib.h>
4#include <string.h>
5#include <assert.h>
6
7#include "misc.h"
8#include "mpconfig.h"
9#include "mpz.h"
10
11#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
12
13#define DIG_SIZE (15)
14#define DIG_MASK ((1 << DIG_SIZE) - 1)
15
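/*
 definition of normalise:
 a digit array is normalised when its most significant digit (the one at the
 highest index) is non-zero; the value zero is represented by a length of 0
*/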
20
21/* compares i with j
22 returns sign(i - j)
23 assumes i, j are normalised
24*/
25int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) {
26 if (ilen < jlen) { return -1; }
27 if (ilen > jlen) { return 1; }
28
29 for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
30 int cmp = *(--idig) - *(--jdig);
31 if (cmp < 0) { return -1; }
32 if (cmp > 0) { return 1; }
33 }
34
35 return 0;
36}
37
38/* computes i = j >> n
39 returns number of digits in i
40 assumes enough memory in i; assumes normalised j
41 can have i, j pointing to same memory
42*/
43uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
44 uint n_whole = n / DIG_SIZE;
45 uint n_part = n % DIG_SIZE;
46
47 if (n_whole >= jlen) {
48 return 0;
49 }
50
51 jdig += n_whole;
52 jlen -= n_whole;
53
54 for (uint i = jlen; i > 0; --i, ++idig, ++jdig) {
55 mpz_dbl_dig_t d = *jdig;
56 if (i > 1)
57 d |= jdig[1] << DIG_SIZE;
58 d >>= n_part;
59 *idig = d & DIG_MASK;
60 }
61
62 if (idig[-1] == 0) {
63 --jlen;
64 }
65
66 return jlen;
67}
68
69/* computes i = j + k
70 returns number of digits in i
71 assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
72 can have i, j, k pointing to same memory
73*/
74uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
75 mpz_dig_t *oidig = idig;
76 mpz_dbl_dig_t carry = 0;
77
78 jlen -= klen;
79
80 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
81 carry += *jdig + *kdig;
82 *idig = carry & DIG_MASK;
83 carry >>= DIG_SIZE;
84 }
85
86 for (; jlen > 0; --jlen, ++idig, ++jdig) {
87 carry += *jdig;
88 *idig = carry & DIG_MASK;
89 carry >>= DIG_SIZE;
90 }
91
92 if (carry != 0) {
93 *idig++ = carry;
94 }
95
96 return idig - oidig;
97}
98
99/* computes i = j - k
100 returns number of digits in i
101 assumes enough memory in i; assumes normalised j, k; assumes j >= k
102 can have i, j, k pointing to same memory
103*/
104uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
105 mpz_dig_t *oidig = idig;
106 mpz_dbl_dig_signed_t borrow = 0;
107
108 jlen -= klen;
109
110 for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
111 borrow += *jdig - *kdig;
112 *idig = borrow & DIG_MASK;
113 borrow >>= DIG_SIZE;
114 }
115
Damien Georgeaca14122014-02-24 21:32:52 +0000116 for (; jlen > 0; --jlen, ++idig, ++jdig) {
Damien George438c88d2014-02-22 19:25:23 +0000117 borrow += *jdig;
118 *idig = borrow & DIG_MASK;
119 borrow >>= DIG_SIZE;
120 }
121
122 for (--idig; idig >= oidig && *idig == 0; --idig) {
123 }
124
125 return idig + 1 - oidig;
126}
127
128/* computes i = i * d1 + d2
129 returns number of digits in i
130 assumes enough memory in i; assumes normalised i; assumes dmul != 0
131*/
132uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
133 mpz_dig_t *oidig = idig;
134 mpz_dbl_dig_t carry = dadd;
135
136 for (; ilen > 0; --ilen, ++idig) {
137 carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
138 *idig = carry & DIG_MASK;
139 carry >>= DIG_SIZE;
140 }
141
142 if (carry != 0) {
143 *idig++ = carry;
144 }
145
146 return idig - oidig;
147}
148
149/* computes i = j * k
150 returns number of digits in i
151 assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
152 can have j, k point to same memory
153*/
154uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) {
155 mpz_dig_t *oidig = idig;
156 uint ilen = 0;
157
158 for (; klen > 0; --klen, ++idig, ++kdig) {
159 mpz_dig_t *id = idig;
160 mpz_dbl_dig_t carry = 0;
161
162 uint jl = jlen;
163 for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
164 carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
165 *id = carry & DIG_MASK;
166 carry >>= DIG_SIZE;
167 }
168
169 if (carry != 0) {
170 *id++ = carry;
171 }
172
173 ilen = id - oidig;
174 }
175
176 return ilen;
177}
178
179/* natural_div - quo * den + new_num = old_num (ie num is replaced with rem)
180 assumes den != 0
181 assumes num_dig has enough memory to be extended by 1 digit
182 assumes quo_dig has enough memory (as many digits as num)
183 assumes quo_dig is filled with zeros
184 modifies den_dig memory, but restors it to original state at end
185*/
186
187void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) {
188 mpz_dig_t *orig_num_dig = num_dig;
189 mpz_dig_t *orig_quo_dig = quo_dig;
190 mpz_dig_t norm_shift = 0;
191 mpz_dbl_dig_t lead_den_digit;
192
193 // handle simple cases
194 {
195 int cmp = mpn_cmp(num_dig, *num_len, den_dig, den_len);
196 if (cmp == 0) {
197 *num_len = 0;
198 quo_dig[0] = 1;
199 *quo_len = 1;
200 return;
201 } else if (cmp < 0) {
202 // numerator remains the same
203 *quo_len = 0;
204 return;
205 }
206 }
207
208 // count number of leading zeros in leading digit of denominator
209 {
210 mpz_dig_t d = den_dig[den_len - 1];
211 while ((d & (1 << (DIG_SIZE - 1))) == 0) {
212 d <<= 1;
213 ++norm_shift;
214 }
215 }
216
217 // normalise denomenator (leading bit of leading digit is 1)
218 for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
219 mpz_dig_t d = *den;
220 *den = ((d << norm_shift) | carry) & DIG_MASK;
221 carry = d >> (DIG_SIZE - norm_shift);
222 }
223
224 // now need to shift numerator by same amount as denominator
225 // first, increase length of numerator in case we need more room to shift
226 num_dig[*num_len] = 0;
227 ++(*num_len);
228 for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) {
229 mpz_dig_t n = *num;
230 *num = ((n << norm_shift) | carry) & DIG_MASK;
231 carry = n >> (DIG_SIZE - norm_shift);
232 }
233
234 // cache the leading digit of the denominator
235 lead_den_digit = den_dig[den_len - 1];
236
237 // point num_dig to last digit in numerator
238 num_dig += *num_len - 1;
239
240 // calculate number of digits in quotient
241 *quo_len = *num_len - den_len;
242
243 // point to last digit to store for quotient
244 quo_dig += *quo_len - 1;
245
246 // keep going while we have enough digits to divide
247 while (*num_len > den_len) {
248 mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1];
249
250 // get approximate quotient
251 quo /= lead_den_digit;
252
253 // multiply quo by den and subtract from num get remainder
254 {
255 mpz_dbl_dig_signed_t borrow = 0;
256
257 for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
258 borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16
259 *n = borrow & DIG_MASK;
260 borrow >>= DIG_SIZE;
261 }
262 borrow += *num_dig; // will overflow if DIG_SIZE >= 16
263 *num_dig = borrow & DIG_MASK;
264 borrow >>= DIG_SIZE;
265
266 // adjust quotient if it is too big
267 for (; borrow != 0; --quo) {
268 mpz_dbl_dig_t carry = 0;
269 for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
270 carry += *n + *d;
271 *n = carry & DIG_MASK;
272 carry >>= DIG_SIZE;
273 }
274 carry += *num_dig;
275 *num_dig = carry & DIG_MASK;
276 carry >>= DIG_SIZE;
277
278 borrow += carry;
279 }
280 }
281
282 // store this digit of the quotient
283 *quo_dig = quo & DIG_MASK;
284 --quo_dig;
285
286 // move down to next digit of numerator
287 --num_dig;
288 --(*num_len);
289 }
290
291 // unnormalise denomenator
292 for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
293 mpz_dig_t d = *den;
294 *den = ((d >> norm_shift) | carry) & DIG_MASK;
295 carry = d << (DIG_SIZE - norm_shift);
296 }
297
298 // unnormalise numerator (remainder now)
299 for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
300 mpz_dig_t n = *num;
301 *num = ((n >> norm_shift) | carry) & DIG_MASK;
302 carry = n << (DIG_SIZE - norm_shift);
303 }
304
305 // strip trailing zeros
306
307 while (*quo_len > 0 && orig_quo_dig[*quo_len - 1] == 0) {
308 --(*quo_len);
309 }
310
311 while (*num_len > 0 && orig_num_dig[*num_len - 1] == 0) {
312 --(*num_len);
313 }
314}
315
316#define MIN_ALLOC (4)
317#define ALIGN_ALLOC (2)
Damien Georgeaca14122014-02-24 21:32:52 +0000318#define NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / DIG_SIZE + 1)
Damien George438c88d2014-02-22 19:25:23 +0000319
320static const uint log_base2_floor[] = {
321 0,
322 0, 1, 1, 2,
323 2, 2, 2, 3,
324 3, 3, 3, 3,
325 3, 3, 3, 4,
326 4, 4, 4, 4,
327 4, 4, 4, 4,
328 4, 4, 4, 4,
329 4, 4, 4, 5
330};
331
Damien Georgeaca14122014-02-24 21:32:52 +0000332bool mpz_int_is_sml_int(machine_int_t i) {
Damien George438c88d2014-02-22 19:25:23 +0000333 return -(1 << DIG_SIZE) < i && i < (1 << DIG_SIZE);
334}
335
336void mpz_init_zero(mpz_t *z) {
337 z->alloc = 0;
338 z->neg = 0;
339 z->len = 0;
340 z->dig = NULL;
341}
342
343void mpz_init_from_int(mpz_t *z, machine_int_t val) {
344 mpz_init_zero(z);
345 mpz_set_from_int(z, val);
346}
347
348void mpz_deinit(mpz_t *z) {
349 if (z != NULL) {
350 m_del(mpz_dig_t, z->dig, z->alloc);
351 }
352}
353
354mpz_t *mpz_zero(void) {
355 mpz_t *z = m_new_obj(mpz_t);
356 mpz_init_zero(z);
357 return z;
358}
359
360mpz_t *mpz_from_int(machine_int_t val) {
361 mpz_t *z = mpz_zero();
362 mpz_set_from_int(z, val);
363 return z;
364}
365
366mpz_t *mpz_from_str(const char *str, uint len, bool neg, uint base) {
367 mpz_t *z = mpz_zero();
368 mpz_set_from_str(z, str, len, neg, base);
369 return z;
370}
371
372void mpz_free(mpz_t *z) {
373 if (z != NULL) {
374 m_del(mpz_dig_t, z->dig, z->alloc);
375 m_del_obj(mpz_t, z);
376 }
377}
378
379STATIC void mpz_need_dig(mpz_t *z, uint need) {
380 uint alloc;
381 if (need < MIN_ALLOC) {
382 alloc = MIN_ALLOC;
383 } else {
384 alloc = (need + ALIGN_ALLOC) & (~(ALIGN_ALLOC - 1));
385 }
386
387 if (z->dig == NULL || z->alloc < alloc) {
388 z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, alloc);
389 z->alloc = alloc;
390 }
391}
392
393mpz_t *mpz_clone(const mpz_t *src) {
394 mpz_t *z = m_new_obj(mpz_t);
395 z->alloc = src->alloc;
396 z->neg = src->neg;
397 z->len = src->len;
398 if (src->dig == NULL) {
399 z->dig = NULL;
400 } else {
401 z->dig = m_new(mpz_dig_t, z->alloc);
402 memcpy(z->dig, src->dig, src->alloc * sizeof(mpz_dig_t));
403 }
404 return z;
405}
406
407void mpz_set(mpz_t *dest, const mpz_t *src) {
408 mpz_need_dig(dest, src->len);
409 dest->neg = src->neg;
410 dest->len = src->len;
411 memcpy(dest->dig, src->dig, src->len * sizeof(mpz_dig_t));
412}
413
414void mpz_set_from_int(mpz_t *z, machine_int_t val) {
415 mpz_need_dig(z, NUM_DIG_FOR_INT);
416
417 if (val < 0) {
418 z->neg = 1;
419 val = -val;
420 } else {
421 z->neg = 0;
422 }
423
424 z->len = 0;
425 while (val > 0) {
426 z->dig[z->len++] = val & DIG_MASK;
427 val >>= DIG_SIZE;
428 }
429}
430
431// returns number of bytes from str that were processed
432uint mpz_set_from_str(mpz_t *z, const char *str, uint len, bool neg, uint base) {
433 assert(base < 36);
434
435 const char *cur = str;
436 const char *top = str + len;
437
438 mpz_need_dig(z, len * 8 / DIG_SIZE + 1);
439
440 if (neg) {
441 z->neg = 1;
442 } else {
443 z->neg = 0;
444 }
445
446 z->len = 0;
447 for (; cur < top; ++cur) { // XXX UTF8 next char
448 //uint v = char_to_numeric(cur#); // XXX UTF8 get char
449 uint v = *cur;
450 if ('0' <= v && v <= '9') {
451 v -= '0';
452 } else if ('A' <= v && v <= 'Z') {
453 v -= 'A' - 10;
454 } else if ('a' <= v && v <= 'z') {
455 v -= 'a' - 10;
456 } else {
457 break;
458 }
459 if (v >= base) {
460 break;
461 }
462 z->len = mpn_mul_dig_add_dig(z->dig, z->len, base, v);
463 }
464
465 return cur - str;
466}
467
468bool mpz_is_zero(const mpz_t *z) {
469 return z->len == 0;
470}
471
472bool mpz_is_pos(const mpz_t *z) {
473 return z->len > 0 && z->neg == 0;
474}
475
476bool mpz_is_neg(const mpz_t *z) {
477 return z->len > 0 && z->neg != 0;
478}
479
480bool mpz_is_odd(const mpz_t *z) {
481 return z->len > 0 && (z->dig[0] & 1) != 0;
482}
483
484bool mpz_is_even(const mpz_t *z) {
485 return z->len == 0 || (z->dig[0] & 1) == 0;
486}
487
488int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
489 int cmp = z2->neg - z1->neg;
490 if (cmp != 0) {
491 return cmp;
492 }
493 cmp = mpn_cmp(z1->dig, z1->len, z2->dig, z2->len);
494 if (z1->neg != 0) {
495 cmp = -cmp;
496 }
497 return cmp;
498}
499
Damien Georgeaca14122014-02-24 21:32:52 +0000500int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) {
Damien George438c88d2014-02-22 19:25:23 +0000501 int cmp;
502 if (z->neg == 0) {
503 if (sml_int < 0) return 1;
504 if (sml_int == 0) {
505 if (z->len == 0) return 0;
506 return 1;
507 }
508 if (z->len == 0) return -1;
509 assert(sml_int < (1 << DIG_SIZE));
510 if (z->len != 1) return 1;
511 cmp = z->dig[0] - sml_int;
512 } else {
513 if (sml_int > 0) return -1;
514 if (sml_int == 0) {
515 if (z->len == 0) return 0;
516 return -1;
517 }
518 if (z->len == 0) return 1;
519 assert(sml_int > -(1 << DIG_SIZE));
520 if (z->len != 1) return -1;
521 cmp = -z->dig[0] - sml_int;
522 }
523 if (cmp < 0) return -1;
524 if (cmp > 0) return 1;
525 return 0;
526}
527
528/* not finished
529mpz_t *mpz_shl(mpz_t *dest, const mpz_t *lhs, int rhs)
530{
531 if (dest != lhs)
532 dest = mpz_set(dest, lhs);
533
534 if (dest.len == 0 || rhs == 0)
535 return dest;
536
537 if (rhs < 0)
538 return mpz_shr(dest, dest, -rhs);
539
540 printf("mpz_shl: not implemented\n");
541
542 return dest;
543}
544
545mpz_t *mpz_shr(mpz_t *dest, const mpz_t *lhs, int rhs)
546{
547 if (dest != lhs)
548 dest = mpz_set(dest, lhs);
549
550 if (dest.len == 0 || rhs == 0)
551 return dest;
552
553 if (rhs < 0)
554 return mpz_shl(dest, dest, -rhs);
555
556 dest.len = mpn_shr(dest.len, dest.dig, rhs);
557 dest.dig[dest.len .. dest->alloc] = 0;
558
559 return dest;
560}
561*/
562
563
564#if 0
565these functions are unused
566
567/* returns abs(z)
568*/
569mpz_t *mpz_abs(const mpz_t *z) {
570 mpz_t *z2 = mpz_clone(z);
571 z2->neg = 0;
572 return z2;
573}
574
575/* returns -z
576*/
577mpz_t *mpz_neg(const mpz_t *z) {
578 mpz_t *z2 = mpz_clone(z);
579 z2->neg = 1 - z2->neg;
580 return z2;
581}
582
583/* returns lhs + rhs
584 can have lhs, rhs the same
585*/
586mpz_t *mpz_add(const mpz_t *lhs, const mpz_t *rhs) {
587 mpz_t *z = mpz_zero();
588 mpz_add_inpl(z, lhs, rhs);
589 return z;
590}
591
592/* returns lhs - rhs
593 can have lhs, rhs the same
594*/
595mpz_t *mpz_sub(const mpz_t *lhs, const mpz_t *rhs) {
596 mpz_t *z = mpz_zero();
597 mpz_sub_inpl(z, lhs, rhs);
598 return z;
599}
600
601/* returns lhs * rhs
602 can have lhs, rhs the same
603*/
604mpz_t *mpz_mul(const mpz_t *lhs, const mpz_t *rhs) {
605 mpz_t *z = mpz_zero();
606 mpz_mul_inpl(z, lhs, rhs);
607 return z;
608}
609
610/* returns lhs ** rhs
611 can have lhs, rhs the same
612*/
613mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs) {
614 mpz_t *z = mpz_zero();
615 mpz_pow_inpl(z, lhs, rhs);
616 return z;
617}
618#endif
619
620/* computes dest = abs(z)
621 can have dest, z the same
622*/
623void mpz_abs_inpl(mpz_t *dest, const mpz_t *z) {
624 if (dest != z) {
625 mpz_set(dest, z);
626 }
627 dest->neg = 0;
628}
629
630/* computes dest = -z
631 can have dest, z the same
632*/
633void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
634 if (dest != z) {
635 mpz_set(dest, z);
636 }
637 dest->neg = 1 - dest->neg;
638}
639
640/* computes dest = lhs + rhs
641 can have dest, lhs, rhs the same
642*/
643void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
644 if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
645 const mpz_t *temp = lhs;
646 lhs = rhs;
647 rhs = temp;
648 }
649
650 if (lhs->neg == rhs->neg) {
651 mpz_need_dig(dest, lhs->len + 1);
652 dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
653 } else {
654 mpz_need_dig(dest, lhs->len);
655 dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
656 }
657
658 dest->neg = lhs->neg;
659}
660
661/* computes dest = lhs - rhs
662 can have dest, lhs, rhs the same
663*/
664void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
665 bool neg = false;
666
667 if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
668 const mpz_t *temp = lhs;
669 lhs = rhs;
670 rhs = temp;
671 neg = true;
672 }
673
674 if (lhs->neg != rhs->neg) {
675 mpz_need_dig(dest, lhs->len + 1);
676 dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
677 } else {
678 mpz_need_dig(dest, lhs->len);
679 dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
680 }
681
682 if (neg) {
683 dest->neg = 1 - lhs->neg;
684 } else {
685 dest->neg = lhs->neg;
686 }
687}
688
689/* computes dest = lhs * rhs
690 can have dest, lhs, rhs the same
691*/
692void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs)
693{
694 if (lhs->len == 0 || rhs->len == 0) {
695 return mpz_set_from_int(dest, 0);
696 }
697
698 mpz_t *temp = NULL;
699 if (lhs == dest) {
700 lhs = temp = mpz_clone(lhs);
701 if (rhs == dest) {
702 rhs = lhs;
703 }
704 } else if (rhs == dest) {
705 rhs = temp = mpz_clone(rhs);
706 }
707
708 mpz_need_dig(dest, lhs->len + rhs->len); // min mem l+r-1, max mem l+r
709 memset(dest->dig, 0, dest->alloc * sizeof(mpz_dig_t));
710 dest->len = mpn_mul(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
711
712 if (lhs->neg == rhs->neg) {
713 dest->neg = 0;
714 } else {
715 dest->neg = 1;
716 }
717
718 mpz_free(temp);
719}
720
721/* computes dest = lhs ** rhs
722 can have dest, lhs, rhs the same
723*/
724void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
725 if (lhs->len == 0 || rhs->neg != 0) {
726 return mpz_set_from_int(dest, 0);
727 }
728
729 if (rhs->len == 0) {
730 return mpz_set_from_int(dest, 1);
731 }
732
733 mpz_t *x = mpz_clone(lhs);
734 mpz_t *n = mpz_clone(rhs);
735
736 mpz_set_from_int(dest, 1);
737
738 while (n->len > 0) {
739 if (mpz_is_odd(n)) {
740 mpz_mul_inpl(dest, dest, x);
741 }
742 mpz_mul_inpl(x, x, x);
743 n->len = mpn_shr(n->dig, n->dig, n->len, 1);
744 }
745
746 mpz_free(x);
747 mpz_free(n);
748}
749
750/* computes gcd(z1, z2)
751 based on Knuth's modified gcd algorithm (I think?)
752 gcd(z1, z2) >= 0
753 gcd(0, 0) = 0
754 gcd(z, 0) = abs(z)
755*/
756mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2) {
757 if (z1->len == 0) {
758 mpz_t *a = mpz_clone(z2);
759 a->neg = 0;
760 return a;
761 } else if (z2->len == 0) {
762 mpz_t *a = mpz_clone(z1);
763 a->neg = 0;
764 return a;
765 }
766
767 mpz_t *a = mpz_clone(z1);
768 mpz_t *b = mpz_clone(z2);
769 mpz_t c; mpz_init_zero(&c);
770 a->neg = 0;
771 b->neg = 0;
772
773 for (;;) {
774 if (mpz_cmp(a, b) < 0) {
775 if (a->len == 0) {
776 mpz_free(a);
777 mpz_deinit(&c);
778 return b;
779 }
780 mpz_t *t = a; a = b; b = t;
781 }
782 if (!(b->len >= 2 || (b->len == 1 && b->dig[0] > 1))) { // compute b > 0; could be mpz_cmp_small_int(b, 1) > 0
783 break;
784 }
785 mpz_set(&c, b);
786 do {
787 mpz_add_inpl(&c, &c, &c);
788 } while (mpz_cmp(&c, a) <= 0);
789 c.len = mpn_shr(c.dig, c.dig, c.len, 1);
790 mpz_sub_inpl(a, a, &c);
791 }
792
793 mpz_deinit(&c);
794
795 if (b->len == 1 && b->dig[0] == 1) { // compute b == 1; could be mpz_cmp_small_int(b, 1) == 0
796 mpz_free(a);
797 return b;
798 } else {
799 mpz_free(b);
800 return a;
801 }
802}
803
804/* computes lcm(z1, z2)
805 = abs(z1) / gcd(z1, z2) * abs(z2)
806 lcm(z1, z1) >= 0
807 lcm(0, 0) = 0
808 lcm(z, 0) = 0
809*/
810mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2)
811{
812 if (z1->len == 0 || z2->len == 0)
813 return mpz_zero();
814
815 mpz_t *gcd = mpz_gcd(z1, z2);
816 mpz_t *quo = mpz_zero();
817 mpz_t *rem = mpz_zero();
818 mpz_divmod_inpl(quo, rem, z1, gcd);
819 mpz_mul_inpl(rem, quo, z2);
820 mpz_free(gcd);
821 mpz_free(quo);
822 rem->neg = 0;
823 return rem;
824}
825
826/* computes new integers in quo and rem such that:
827 quo * rhs + rem = lhs
828 0 <= rem < rhs
829 can have lhs, rhs the same
830*/
831void mpz_divmod(const mpz_t *lhs, const mpz_t *rhs, mpz_t **quo, mpz_t **rem) {
832 *quo = mpz_zero();
833 *rem = mpz_zero();
834 mpz_divmod_inpl(*quo, *rem, lhs, rhs);
835}
836
837/* computes new integers in quo and rem such that:
838 quo * rhs + rem = lhs
839 0 <= rem < rhs
840 can have lhs, rhs the same
841*/
842void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const mpz_t *rhs) {
843 if (rhs->len == 0) {
844 mpz_set_from_int(dest_quo, 0);
845 mpz_set_from_int(dest_rem, 0);
846 return;
847 }
848
849 mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
850 memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
851 dest_quo->len = 0;
852 mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
853 mpz_set(dest_rem, lhs);
854 //rhs->dig[rhs->len] = 0;
855 mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
856
857 if (lhs->neg != rhs->neg) {
858 dest_quo->neg = 1;
859 }
860}
861
862#if 0
863these functions are unused
864
865/* computes floor(lhs / rhs)
866 can have lhs, rhs the same
867*/
868mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs) {
869 mpz_t *quo = mpz_zero();
870 mpz_t rem; mpz_init_zero(&rem);
871 mpz_divmod_inpl(quo, &rem, lhs, rhs);
872 mpz_deinit(&rem);
873 return quo;
874}
875
876/* computes lhs % rhs ( >= 0)
877 can have lhs, rhs the same
878*/
879mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs) {
880 mpz_t quo; mpz_init_zero(&quo);
881 mpz_t *rem = mpz_zero();
882 mpz_divmod_inpl(&quo, rem, lhs, rhs);
883 mpz_deinit(&quo);
884 return rem;
885}
886#endif
887
Damien Georgeaca14122014-02-24 21:32:52 +0000888machine_int_t mpz_as_int(const mpz_t *i) {
889 machine_int_t val = 0;
Damien George438c88d2014-02-22 19:25:23 +0000890 mpz_dig_t *d = i->dig + i->len;
891
892 while (--d >= i->dig)
893 {
Damien Georgeaca14122014-02-24 21:32:52 +0000894 machine_int_t oldval = val;
Damien George438c88d2014-02-22 19:25:23 +0000895 val = (val << DIG_SIZE) | *d;
896 if (val < oldval)
897 {
898 if (i->neg == 0) {
899 return 0x7fffffff;
900 } else {
901 return 0x80000000;
902 }
903 }
904 }
905
906 if (i->neg != 0) {
907 val = -val;
908 }
909
910 return val;
911}
912
913machine_float_t mpz_as_float(const mpz_t *i) {
914 machine_float_t val = 0;
915 mpz_dig_t *d = i->dig + i->len;
916
917 while (--d >= i->dig) {
918 val = val * (1 << DIG_SIZE) + *d;
919 }
920
921 if (i->neg != 0) {
922 val = -val;
923 }
924
925 return val;
926}
927
928uint mpz_as_str_size(const mpz_t *i, uint base) {
929 if (base < 2 || base > 32) {
930 return 0;
931 }
932
933 return i->len * DIG_SIZE / log_base2_floor[base] + 2 + 1; // +1 for null byte termination
934}
935
936char *mpz_as_str(const mpz_t *i, uint base) {
937 char *s = m_new(char, mpz_as_str_size(i, base));
938 mpz_as_str_inpl(i, base, s);
939 return s;
940}
941
942// assumes enough space as calculated by mpz_as_str_size
943// returns length of string, not including null byte
944uint mpz_as_str_inpl(const mpz_t *i, uint base, char *str) {
945 if (str == NULL || base < 2 || base > 32) {
946 str[0] = 0;
947 return 0;
948 }
949
950 uint ilen = i->len;
951
952 if (ilen == 0) {
953 str[0] = '0';
954 str[1] = 0;
955 return 1;
956 }
957
958 // make a copy of mpz digits
959 mpz_dig_t *dig = m_new(mpz_dig_t, ilen);
960 memcpy(dig, i->dig, ilen * sizeof(mpz_dig_t));
961
962 // convert
963 char *s = str;
964 bool done;
965 do {
966 mpz_dig_t *d = dig + ilen;
967 mpz_dbl_dig_t a = 0;
968
969 // compute next remainder
970 while (--d >= dig) {
971 a = (a << DIG_SIZE) | *d;
972 *d = a / base;
973 a %= base;
974 }
975
976 // convert to character
977 a += '0';
978 if (a > '9') {
979 a += 'a' - '9' - 1;
980 }
981 *s++ = a;
982
983 // check if number is zero
984 done = true;
985 for (d = dig; d < dig + ilen; ++d) {
986 if (*d != 0) {
987 done = false;
988 break;
989 }
990 }
991 } while (!done);
992
993 if (i->neg != 0) {
994 *s++ = '-';
995 }
996
997 // reverse string
998 for (char *u = str, *v = s - 1; u < v; ++u, --v) {
999 char temp = *u;
1000 *u = *v;
1001 *v = temp;
1002 }
1003
1004 s[0] = 0; // null termination
1005
1006 return s - str;
1007}
1008
1009#endif // MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ