LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
binvert_1.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../../include/lammp/numth.h"
8#include "../../../include/lammp/lmmpn.h"
9#include "../../../include/lammp/impl/longlong.h"
10
11/* from https://arxiv.org/abs/2204.04342 */
12
/*
 * Seed table for 2-adic inversion: binv_tab[i] is the multiplicative inverse
 * of the odd number (2*i + 1) modulo 2^8, i.e.
 *     ((2*i + 1) * binv_tab[i]) % 256 == 1    for 0 <= i < 128.
 * Index it with (a / 2) & 0x7F for an odd a to get an inverse that is
 * correct in its low 8 bits.
 */
static const unsigned char binv_tab[128] = {
    1,   171, 205, 183, 57,  163, 197, 239, /* i =   0 ..   7 */
    241, 27,  61,  167, 41,  19,  53,  223, /* i =   8 ..  15 */
    225, 139, 173, 151, 25,  131, 165, 207, /* i =  16 ..  23 */
    209, 251, 29,  135, 9,   243, 21,  191, /* i =  24 ..  31 */
    193, 107, 141, 119, 249, 99,  133, 175, /* i =  32 ..  39 */
    177, 219, 253, 103, 233, 211, 245, 159, /* i =  40 ..  47 */
    161, 75,  109, 87,  217, 67,  101, 143, /* i =  48 ..  55 */
    145, 187, 221, 71,  201, 179, 213, 127, /* i =  56 ..  63 */
    129, 43,  77,  55,  185, 35,  69,  111, /* i =  64 ..  71 */
    113, 155, 189, 39,  169, 147, 181, 95,  /* i =  72 ..  79 */
    97,  11,  45,  23,  153, 3,   37,  79,  /* i =  80 ..  87 */
    81,  123, 157, 7,   137, 115, 149, 63,  /* i =  88 ..  95 */
    65,  235, 13,  247, 121, 227, 5,   47,  /* i =  96 .. 103 */
    49,  91,  125, 231, 105, 83,  117, 31,  /* i = 104 .. 111 */
    33,  203, 237, 215, 89,  195, 229, 15,  /* i = 112 .. 119 */
    17,  59,  93,  199, 73,  51,  85,  255  /* i = 120 .. 127 */
};
20
22 lmmp_param_assert(a % 2 == 1);
23 ulong r, y;
24
25 r = binv_tab[(a / 2) & 0x7F]; /* 8 bits */
26 y = 1 - a * r;
27 r = r * (1 + y); /* 16 bits */
28 y *= y;
29 r = r * (1 + y); /* 32 bits */
30 return r;
31}
32
34 lmmp_param_assert(a % 2 == 1);
35 ulong r, y;
36
37 r = binv_tab[(a / 2) & 0x7F]; /* 8 bits */
38 y = 1 - a * r;
39 r = r * (1 + y); /* 16 bits */
40 y *= y;
41 r = r * (1 + y); /* 32 bits */
42 y *= y;
43 r = r * (1 + y); /* 64 bits */
44 return r;
45}
46
48 lmmp_param_assert(numa[0] % 2 == 1);
49 mp_limb_t k, t;
50 mp_limb_t a1 = numa[1];
51 mp_limb_t a0 = numa[0];
53 mp_limb_t z;
54 /*
55 xn * a0 == 1 + k * B
56 yn := xn * (2 - a * xn) mod B^2
57 := xn * (2 - a0 * xn - a1 * xn * B) mod B^2
58 := xn * (2 - 1 - k*B - a1 * xn * B) mod B^2
59 := xn * (1 - k*B - a1 * xn * B) mod B^2
60 := (xn - xn*k * B - a1 * xn^2 * B) mod B^2
61 */
62 _umul64to128_(a0, xn, &t, &k);
63 z = xn * k;
64 z += a1 * xn * xn;
65 dst[0] = xn;
66 dst[1] = -z;
67}
68
/*
 * Low 192 bits of the product of two 128-bit values.
 *
 * (a_high:a_low) and (b_high:b_low) are the operands as high/low 64-bit
 * halves.  rr[0], rr[1], rr[2] receive the three least significant 64-bit
 * words of the product, least significant first; anything above bit 191 is
 * discarded.
 */
static inline void _umul128to192_(uint64_t a_high, uint64_t a_low, uint64_t b_high, uint64_t b_low, uint64_t rr[3]) {
    uint64_t cross_lh_lo, cross_lh_hi; /* a_low  * b_high */
    uint64_t cross_hl_lo, cross_hl_hi; /* a_high * b_low  */

    _umulx64to128_(a_low, b_low, &rr[0], &rr[1]);
    _umulx64to128_(a_low, b_high, &cross_lh_lo, &cross_lh_hi);
    _umulx64to128_(a_high, b_low, &cross_hl_lo, &cross_hl_hi);
    /*
       Column layout of the partial products (l = low word, h = high word):

         |  rr[0]  |  rr[1]  |  rr[2]  |
         | lo*lo l | lo*lo h |         |
         |         | lo*hi l | lo*hi h |
         |         | hi*lo l | hi*lo h |
         |         |         | hi*hi l |
     */

    /* Middle word: add both cross-product low words, counting the wraps. */
    uint64_t mid = rr[1] + cross_lh_lo;
    uint64_t carries = (mid < cross_lh_lo);
    mid += cross_hl_lo;
    carries += (mid < cross_hl_lo);
    rr[1] = mid;

    /* Top word: only the low 64 bits of a_high * b_high are needed. */
    rr[2] = a_high * b_high + carries + cross_lh_hi + cross_hl_hi;
}
92
93void lmmp_binvert_3_(mp_ptr restrict dst, mp_srcptr restrict numa) {
94 /*
95 a == a0 + a1 * B^2
96 xn * a0 == 1 + k * B^2
97 yn := xn * (2 - a * xn) mod B^4
98 := xn * (2 - a0 * xn - a1 * xn * B^2) mod B^4
99 := xn * (2 - 1 - k*B^2 - a1 * xn * B^2) mod B^4
100 := xn * (1 - k*B^2 - a1 * xn * B^2) mod B^4
101 := (xn - xn*k * B^2 - a1 * xn^2 * B^2) mod B^4
102 */
103 lmmp_binvert_2_(dst, numa);
104 mp_limb_t k[3];
105 mp_limb_t z;
106 mp_limb_t a2 = numa[2];
107 _umul128to192_(dst[1], dst[0], numa[1], numa[0], k);
108 lmmp_debug_assert(k[1] == 0 && k[0] == 1);
109#define xn (dst[0])
110#define k (k[2])
111 z = xn * k;
112 z += a2 * xn * xn;
113 dst[2] = -z;
114#undef xn
115#undef k
116}
117
118void lmmp_binvert_4_(mp_ptr restrict dst, mp_srcptr restrict numa) {
119 /*
120 a == a0 + a1 * B^2
121 xn * a0 == 1 + k * B^2
122 yn := xn * (2 - a * xn) mod B^4
123 := xn * (2 - a0 * xn - a1 * xn * B^2) mod B^4
124 := xn * (2 - 1 - k*B^2 - a1 * xn * B^2) mod B^4
125 := xn * (1 - k*B^2 - a1 * xn * B^2) mod B^4
126 := (xn - xn*k * B^2 - a1 * xn^2 * B^2) mod B^4
127 */
128 lmmp_binvert_2_(dst, numa);
129 mp_limb_t k[4];
130 mp_limb_t z[2];
131 mp_limb_t t[2];
132 _umul128to256_(dst[1], dst[0], numa[1], numa[0], k);
133 lmmp_debug_assert(k[1] == 0 && k[0] == 1);
134
135#define xn (dst)
136#define k (k + 2)
137 _umul128to128_(k[1], k[0], xn[1], xn[0], z);
138
139 _umul64to128_(xn[0], xn[0], t, t + 1);
140 t[1] += (xn[1] * xn[0]) << 1;
141 _umul128to128_(t[1], t[0], numa[3], numa[2], t);
142
143 _u128add(z, z, t);
144 dst[2] = 0;
145 dst[3] = 0;
146 _u128sub(dst + 2, dst + 2, z);
147
148#undef xn
149#undef k
150}
#define k
ulong lmmp_binvert_ulong_(ulong a)
计算 a 在2^64下的逆元
Definition binvert_1.c:33
uint lmmp_binvert_uint_(uint a)
计算 a 在2^32下的逆元
Definition binvert_1.c:21
void lmmp_binvert_2_(mp_ptr dst, mp_srcptr numa)
计算 [numa,2] 在B^2下的逆元
Definition binvert_1.c:47
void lmmp_binvert_3_(mp_ptr restrict dst, mp_srcptr restrict numa)
Definition binvert_1.c:93
#define xn
static void _umul128to192_(uint64_t a_high, uint64_t a_low, uint64_t b_high, uint64_t b_low, uint64_t rr[3])
Definition binvert_1.c:69
void lmmp_binvert_4_(mp_ptr restrict dst, mp_srcptr restrict numa)
Definition binvert_1.c:118
static const unsigned char binv_tab[128]
Definition binvert_1.c:13
mp_limb_t * mp_ptr
Definition lmmp.h:215
#define lmmp_debug_assert(x)
Definition lmmp.h:387
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_param_assert(x)
Definition lmmp.h:398
#define _u128add(r, x, y)
Definition longlong.h:260
static void _umul64to128_(uint64_t a, uint64_t b, uint64_t *low, uint64_t *high)
Definition longlong.h:31
static void _umul128to256_(uint64_t a_high, uint64_t a_low, uint64_t b_high, uint64_t b_low, uint64_t rr[4])
Definition longlong.h:86
#define _u128sub(r, x, y)
Definition longlong.h:282
static void _umul128to128_(uint64_t a_high, uint64_t a_low, uint64_t b_high, uint64_t b_low, uint64_t rr[2])
Definition longlong.h:115
static void _umulx64to128_(uint64_t a, uint64_t b, uint64_t *low, uint64_t *high)
Definition longlong.h:78
#define a0
#define a1
#define a2
uint32_t uint
Definition numth.h:35
uint64_t ulong
Definition numth.h:36