LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
binvert.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../../include/lammp/numth.h"
8#include "../../../include/lammp/lmmpn.h"
9#include "../../../include/lammp/impl/mparam.h"
10#include "../../../include/lammp/impl/tmp_alloc.h"
11
12
13static inline void binvert_mulhi_(mp_ptr dst, mp_srcptr xp, mp_srcptr ap, mp_size_t n, mp_ptr tp) {
15 lmmp_mul_n_(tp, xp, ap, n);
16 lmmp_copy(dst, tp + n, n);
17 } else {
18 mp_size_t m = lmmp_fft_next_size_((n * 2 + 1) >> 1);
19 lmmp_debug_assert(n * 2 > m && m >= n);
20 lmmp_mul_mersenne_(tp, m, xp, n, ap, n);
21 lmmp_dec(tp);
22 mp_size_t fn = m - n; // 从 tp+n 开始的长度
23 mp_size_t sn = n - fn; // 从 tp 开始的长度
24 lmmp_copy(dst, tp + n, fn);
25 lmmp_copy(dst + fn, tp, sn);
26 }
27}
28
29static inline void lmmp_sqrlo_n_(
30 mp_ptr restrict dst,
31 mp_srcptr restrict numa,
32 mp_size_t n,
33 mp_ptr restrict tp
34) {
35 if (n < MULLO_DC_THRESHOLD) {
36 lmmp_sqrlo_dc_(dst, numa, tp, n);
37 } else {
38 lmmp_mullo_fft_(dst, numa, numa, n, tp);
39 }
40}
41
42static inline void lmmp_mullo_n_(
43 mp_ptr restrict dst,
44 mp_srcptr restrict numa,
45 mp_srcptr restrict numb,
46 mp_size_t n,
47 mp_ptr restrict tp
48) {
49 if (n < MULLO_DC_THRESHOLD) {
50 lmmp_mullo_dc_(dst, numa, numb, tp, n);
51 } else {
52 lmmp_mullo_fft_(dst, numa, numb, n, tp);
53 }
54}
55
56/*
57balanced:
58 a := [numa,2*n]
59 we need to find x such that x * a == 1 mod B^2n
60 we know that a == a_lo + a_hi * B^n
61 and x_lo == a_lo ^ -1 mod B^n
62 means x_lo * a_lo == 1 + k * B^n and k < B^n
63
64 x = x_lo * (2 - a * x_lo) mod B^2n
65 = x_lo * (2 - a_lo * x_lo - a_hi * x_lo * B^n) mod B^2n
66 = x_lo * (1 - k * B^n - a_hi * x_lo * B^n) mod B^2n
67 = x_lo - (k * x_lo + a_hi * x_lo^2) * B^n mod B^2n
68-----------------------------------------------------------------------------
69unbalanced:
70 a := [numa,na]
71 We need x such that x * a == 1 mod B^n, where n is much larger than na.
72 First compute x0 = a ^ -1 mod B^na; this is a balanced inverse.
73 Then extend it by a linear recurrence, processing na limbs at a time.
74 Suppose t blocks are known, i.e. Xt = X0 + X1*B^na + X2*B^(2*na) + ... + X{t-1}*B^((t-1)*na)
75 and a*Xt == 1 mod B^(t*na).
76 This can be written as a*Xt = 1 + k * B^(t*na), k < B^na.
77 We need the next block p such that X{t+1} = X{t} + p*B^(t*na).
78 Substituting into a*X{t+1} == 1 mod B^((t+1)*na)
79 gives
80 1 + k * B^(t*na) + a*p*B^(t*na) == 1 mod B^((t+1)*na)
81 k + a*p == 0 mod B^na
82 p == -k * a^-1 mod B^na
83 With the new X{t+1} in hand, k must be updated to k'.
84 k' has to satisfy
85 a*X{t+1} = 1 + k' * B^((t+1)*na), k' < B^na
86 k' * B^na = k + a*p
87 k' = (k + a*p) / B^na
88*/
89
90
91void lmmp_binvert_n_dc_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t n, mp_ptr restrict tp) {
92 lmmp_param_assert(dst != NULL && tp != NULL);
93 lmmp_param_assert(numa != NULL && n > 0);
94 lmmp_param_assert(numa[0] % 2 == 1);
95 if (n == 1) {
96 dst[0] = lmmp_binvert_ulong_(numa[0]);
97 } else if (n == 2) {
98 lmmp_binvert_2_(dst, numa);
99 } else if (n == 3) {
100 lmmp_binvert_3_(dst, numa);
101 } else if (n == 4) {
102 lmmp_binvert_4_(dst, numa);
103 } else if (n % 2 == 0) {
104 mp_size_t halfn = n / 2;
105
106#define k (tp) // [tp, halfn]
107#define alo (numa) // [numa, halfn]
108#define ahi (numa + halfn) // [numa+halfn, halfn]
109#define xlo (dst) // [dst, halfn]
110#define xhi (dst + halfn) // [dst+halfn, halfn]
111#define xlo_sqr (tp + halfn) // [tp+halfn, halfn]
112#define xlo_sqr_mul_ahi (tp + 2 * halfn) // [tp+2*halfn, halfn]
113#define scratch (tp + 3 * halfn) // [tp+3*halfn,2*halfn]
114// ________________________________________________________________
115// tp : |_________________________5*(n+1)/2____________________________|
116// | k | xlo_sqr | xlo_sqr_mul_ahi | scratch | remaining |
117// |_halfn_|__halfn__|______halfn______|___2*halfn___| |
118
119 lmmp_binvert_n_dc_(xlo, alo, halfn, tp);
120 binvert_mulhi_(k, xlo, alo, halfn, tp + halfn);
123 lmmp_mullo_n_(xhi, xlo, k, halfn, scratch);
125 lmmp_not_(xhi, xhi, halfn);
126 lmmp_inc(xhi);
127 } else {
128 mp_size_t halfn = n / 2 + 1;
129 mp_size_t ahin = n - halfn;
130
131#define k (tp) // [tp, halfn]
132#define alo (numa) // [numa, halfn]
133#define ahi (numa + halfn) // [numa+halfn, ahin]
134#define xlo (dst) // [dst, halfn]
135#define xhi (dst + halfn) // [dst+halfn, ahin]
136#define xlo_sqr (tp + halfn) // [tp+halfn, ahin]
137#define xlo_sqr_mul_ahi (tp + 2 * halfn) // [tp+2*halfn, ahin]
138#define scratch (tp + 3 * halfn) // [tp+3*halfn, 2*ahin]
139// ________________________________________________________________
140// tp : |_________________________5*(n+1)/2____________________________|
141// | k | xlo_sqr | xlo_sqr_mul_ahi | scratch | remaining |
142// |__halfn__|__halfn__|______halfn______|___2*ahin___| |
143
144 lmmp_binvert_n_dc_(xlo, alo, halfn, tp);
145 binvert_mulhi_(k, xlo, alo, halfn, tp + halfn);
148 lmmp_mullo_n_(xhi, xlo, k, ahin, scratch);
150 lmmp_not_(xhi, xhi, ahin);
151 lmmp_inc(xhi);
152 }
153}
154#undef k
155#undef alo
156#undef ahi
157#undef xlo
158#undef xhi
159#undef xlo_sqr
160#undef xlo_sqr_mul_ahi
161#undef scratch
#define k
#define ahi
void lmmp_binvert_n_dc_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t n, mp_ptr restrict tp)
Definition binvert.c:91
#define xlo_sqr_mul_ahi
#define xlo_sqr
static void binvert_mulhi_(mp_ptr dst, mp_srcptr xp, mp_srcptr ap, mp_size_t n, mp_ptr tp)
Definition binvert.c:13
#define xhi
#define scratch
#define alo
static void lmmp_sqrlo_n_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t n, mp_ptr restrict tp)
Definition binvert.c:29
#define xlo
static void lmmp_mullo_n_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_size_t n, mp_ptr restrict tp)
Definition binvert.c:42
mp_limb_t * mp_ptr
Definition lmmp.h:215
#define lmmp_copy(dst, src, n)
Definition lmmp.h:364
uint64_t mp_size_t
Definition lmmp.h:212
#define lmmp_debug_assert(x)
Definition lmmp.h:387
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
#define lmmp_param_assert(x)
Definition lmmp.h:398
void lmmp_mul_mersenne_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
梅森数模乘法 [dst,rn] = [numa,na]*[numb,nb] mod B^rn-1
Definition mul_fft.c:752
#define lmmp_dec(p)
大数减1宏(预期无借位)
Definition lmmpn.h:973
void lmmp_mullo_dc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_ptr tp, mp_size_t n)
低位乘法 [dst,n] = [numa,n] * [numb,n] mod B^n
#define lmmp_inc(p)
大数加1宏(预期无进位)
Definition lmmpn.h:946
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_size_t lmmp_fft_next_size_(mp_size_t n)
计算满足 >=n 的最小费马/梅森乘法可行尺寸
Definition mul_fft.c:84
void lmmp_sqrlo_dc_(mp_ptr dst, mp_srcptr numa, mp_ptr tp, mp_size_t n)
低位平方 [dst,n] = [numa,n]^2 mod B^n
void lmmp_mullo_fft_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_ptr scratch)
低位FFT乘法 [dst,n] = [numa,n] * [numb,n] mod B^n
Definition mullo.c:11
void lmmp_not_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数按位取反操作 [dst,na] = ~[numa,na] (对每个limb执行按位非操作)
mp_limb_t lmmp_add_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无进位的n位加法 [dst,n] = [numa,n] + [numb,n]
Definition add_n.c:71
#define BINVERT_MULHI_MERSENNE_THRESHOLD
Definition mparam.h:130
#define MULLO_DC_THRESHOLD
Definition mparam.h:59
#define tp
ulong lmmp_binvert_ulong_(ulong a)
计算 a 在2^64下的逆元
Definition binvert_1.c:33
void lmmp_binvert_2_(mp_ptr dst, mp_srcptr numa)
计算 [numa,2] 在B^2下的逆元
Definition binvert_1.c:47
void lmmp_binvert_3_(mp_ptr dst, mp_srcptr numa)
计算 [numa,3] 在B^3下的逆元
void lmmp_binvert_4_(mp_ptr dst, mp_srcptr numa)
计算 [numa,4] 在B^4下的逆元