7#include "../../include/lammp/impl/mparam.h"
8#include "../../include/lammp/impl/tmp_alloc.h"
9#include "../../include/lammp/lmmpn.h"
12#define _FFT_TABLE_ENTRY(n) {((mp_size_t)3 << (2 * (n) - 5)) + 1, (n)}
13#define _FFT_TABLE_ENTRY4(n) \
14 _FFT_TABLE_ENTRY(n), _FFT_TABLE_ENTRY((n) + 1), _FFT_TABLE_ENTRY((n) + 2), _FFT_TABLE_ENTRY((n) + 3)
88 n = (((n - 1) >>
k) + 1) <<
k;
132 lmmp_shr_(dst, numa + offset, lena, shr);
183 cc =
lmmp_shl_(dst + w, src, l - w, shl);
202 dst[l] = dst[0] < cc;
246 scyo =
lmmp_sub_nc_(numc + w + off, numa + off, numb + off, cursize, scyo);
247 acyo =
lmmp_add_nc_(numa + off, numa + off, numb + off, cursize, acyo);
249 shlcyo =
lmmp_shl_c_(numc + w + off, numc + w + off, cursize, shl, shlcyo);
252 ch = shlcyo + (-scyo << shl);
263 scyo =
lmmp_sub_nc_(numc + off - (l - w), numb + off, numa + off, cursize, scyo);
264 acyo =
lmmp_add_nc_(numa + off, numa + off, numb + off, cursize, acyo);
266 shlcyo =
lmmp_shl_c_(numc + off - (l - w), numc + off - (l - w), cursize, shl, shlcyo);
270 scyo = -scyo + numb[l] - numa[l];
271 acyo += numa[l] + numb[l];
293 numc[l] = numc[0] < chp;
327 lmmp_shr_c_(numb + off - (l - w), numb + off - (l - w), cursize, shr,
328 numb[off - (l - w) + cursize] << (
LIMB_BITS - shr));
329 bcyo =
lmmp_add_nc_(numc + off, numa + off, numb + off - (l - w), cursize, bcyo);
330 acyo =
lmmp_sub_nc_(numa + off, numa + off, numb + off - (l - w), cursize, acyo);
336 lmmp_shr_c_(numb + w + off, numb + w + off, cursize, shr, numb[off + w + cursize] << (
LIMB_BITS - shr));
337 bcyo =
lmmp_sub_nc_(numc + off, numa + off, numb + w + off, cursize, bcyo);
338 acyo =
lmmp_add_nc_(numa + off, numa + off, numb + w + off, cursize, acyo);
341 acyo += numb[l] >> shr;
342 bcyo = -bcyo - (numb[l] >> shr);
344 acyo -= numa[l - w - 1] < shrcyo;
345 numa[l - w - 1] -= shrcyo;
346 numc[l - w - 1] += shrcyo;
347 bcyo += numc[l - w - 1] < shrcyo;
406 for (
mp_size_t i = 0; i < Kq; i += dis) {
465 for (
mp_size_t i = 0; i < Kq; i += dis) {
534 if (nums[nlen - 1]) {
538 if (nums[nlen - 1] == 0 && nums[nlen - 2] >> (
LIMB_BITS - 1)) {
551 borrow = -nums[nlen - 1];
564 }
else if (roffset + nlen <= rn) {
565 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
566 rhead = roffset + nlen;
568 maxc +=
lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
569 maxc -=
lmmp_sub_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
601 int nsqr = pc1 != pc2;
630 n = (((n - 1) >>
k) + 1) <<
k;
637 for (
mp_size_t i = 0; i < K; ++i) pfca[i] = (
mp_ptr)(pfca + K) + i * nlen;
639 pfcb += (nlen + 1) <<
k;
640 for (
mp_size_t i = 0; i < K; ++i) pfcb[i] = (
mp_ptr)(pfcb + K) + i * nlen;
678 int nsqr = numa != numb || na != nb;
687 n = (((n - 1) >>
k) + 1) <<
k;
703 pfca[i] = (
mp_ptr)(pfca + K) + i * nlen;
705 coeflen = M + (i == K - 1);
706 coeflen =
LMMP_MIN(narest, coeflen);
719 pfcb += (nlen + 1) <<
k;
722 pfcb[i] = (
mp_ptr)(pfcb + K) + i * nlen;
724 coeflen = M + (i == K - 1);
725 coeflen =
LMMP_MIN(nbrest, coeflen);
753 int nsqr = numa != numb || na != nb;
764 n = (((n - 1) >> (
k - 1)) + 1) << (
k - 1);
780 pfca[i] = (
mp_ptr)(pfca + K) + i * nlen;
792 pfcb += (nlen + 1) <<
k;
795 pfcb[i] = (
mp_ptr)(pfcb + K) + i * nlen;
816 if (nums[nlen - 1]) {
832 }
else if (roffset + nlen <= rn) {
833 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
834 rhead = roffset + nlen;
836 maxc +=
lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
837 maxc +=
lmmp_add_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
873 int nsqr = numa != numb || na != nb;
883 n = (((n - 1) >>
k) + 1) <<
k;
915 pfca[i] = (
mp_ptr)(pfca + K) + i * nlen;
917 coeflen = M + (i == K - 1);
918 coeflen =
LMMP_MIN(narest, coeflen);
933 pfcb[i] = (
mp_ptr)(pfcb + K) + i * nlen;
935 coeflen = M + (i == K - 1);
936 coeflen =
LMMP_MIN(nbrest, coeflen);
971 int nsqr = numa != numb || na != nb;
983 n = (((n - 1) >> (
k - 1)) + 1) << (
k - 1);
1017 pfca[i] = (
mp_ptr)(pfca + K) + i * nlen;
1032 pfcb[i] = (
mp_ptr)(pfcb + K) + i * nlen;
1053 if (nums[nlen - 1]) {
1069 }
else if (roffset + nlen <= rn) {
1070 lmmp_add_(dst + roffset, nums, nlen, dst + roffset, rhead - roffset);
1071 rhead = roffset + nlen;
1073 maxc +=
lmmp_add_(dst + roffset, nums, rn - roffset, dst + roffset, rhead - roffset);
1074 maxc +=
lmmp_add_(dst, dst, rn, nums + rn - roffset, nlen + roffset - rn);
1099 if (
lmmp_add_(dst, numa, hn, numa + hn, na - hn))
1124 if (dst[hn - 1] < cy)
1127 if (na + nb == 2 * hn) {
1135 cy =
tp[hn] +
lmmp_sub_nc_(
tp + na + nb - hn, dst + na + nb - hn,
tp + na + nb - hn, 2 * hn - (na + nb), cy);
1162 if (
lmmp_add_(dst, numa, hn, numa + hn, na - hn))
1187 if (dst[hn - 1] < cy)
1190 if (na + nb == 2 * hn) {
1198 cy =
tp[hn] +
lmmp_sub_nc_(
tp + na + nb - hn, dst + na + nb - hn,
tp + na + nb - hn, 2 * hn - (na + nb), cy);
1215 sna = (hn << 1) - 1 - nb;
#define lmmp_copy(dst, src, n)
#define lmmp_zero(dst, n)
#define lmmp_debug_assert(x)
void * lmmp_alloc(size_t size)
内存分配函数(调用lmmp_heap_alloc_fn)
const mp_limb_t * mp_srcptr
void lmmp_free(void *ptr)
内存释放函数(调用lmmp_heap_free_fn)
#define lmmp_param_assert(x)
mp_limb_t lmmp_shlnot_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
左移后按位取反操作 [dst,na] = ~([numa,na] << shl),dst的低shl位填充1
static mp_limb_t lmmp_add_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数加法静态内联函数 [dst,na]=[numa,na]+[numb,nb]
mp_limb_t lmmp_shr1add_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带进位加法后右移1位 [dst,n] = ([numa,n] + [numb,n] + c) >> 1
mp_limb_t lmmp_shr_c_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shr, mp_limb_t c)
带进位的大数右移操作 [dst,na] = [numa,na]>>shr,dst的高shr位填充c的高shr位
#define lmmp_dec(p)
大数减1宏(预期无借位)
static mp_limb_t lmmp_add_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数加单精度数静态内联函数 [dst,na]=[numa,na]+x
#define lmmp_inc(p)
大数加1宏(预期无进位)
mp_limb_t lmmp_shr_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shr)
大数右移操作 [dst,na] = [numa,na] >> shr,dst的高shr位填充0
void lmmp_mul_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
不等长大数乘法操作 [dst,na+nb] = [numa,na] * [numb,nb]
void lmmp_sqr_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数平方操作 [dst,2*na] = [numa,na]^2
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
mp_limb_t lmmp_shl_c_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl, mp_limb_t c)
带进位的大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充c的低shl位
mp_limb_t lmmp_add_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带进位的n位加法 [dst,n] = [numa,n] + [numb,n] + c
mp_limb_t lmmp_shl_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充0
static mp_limb_t lmmp_sub_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
大数减法静态内联函数 [dst,na]=[numa,na]-[numb,nb]
#define lmmp_dec_1(p, dec)
大数减指定值宏(预期无借位)
static mp_limb_t lmmp_sub_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
大数减单精度数静态内联函数 [dst,na]=[numa,na]-x
void lmmp_not_(mp_ptr dst, mp_srcptr numa, mp_size_t na)
大数按位取反操作 [dst,na] = ~[numa,na] (对每个limb执行按位非操作)
mp_limb_t lmmp_sub_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无借位的n位减法 [dst,n] = [numa,n] - [numb,n]
#define lmmp_inc_1(p, inc)
大数加指定值宏(预期无进位)
mp_limb_t lmmp_sub_nc_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n, mp_limb_t c)
带借位的n位减法 [dst,n] = [numa,n] - [numb,n] - c
mp_limb_t lmmp_add_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
无进位的n位加法 [dst,n] = [numa,n] + [numb,n]
static int lmmp_zero_q_(mp_srcptr p, mp_size_t n)
大数判零函数(内联)
#define MUL_FFT_MODF_THRESHOLD
static void lmmp_fft_shr_coef_(fft_memstack *ms, mp_ptr *coef, mp_size_t shr)
对模 2^n+1 的系数执行右移操作 右移shr位 = 左移(2n - shr)位(mod 2^n+1的循环特性)
void lmmp_mul_mersenne_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
梅森数模乘法 [dst,rn] = [numa,na]*[numb,nb] mod B^rn-1
static void lmmp_mul_fermat_recurse_(fft_memstack *ms, mp_ptr *pc1, mp_ptr *pc2, mp_size_t K0)
费马变换乘法递归函数(核心乘法逻辑)
static void lmmp_mul_mersenne_single_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
static void lmmp_ifft_bfy_(fft_memstack *ms, mp_ptr *coef, mp_size_t wing, mp_size_t w)
FFT蝶形运算(Butterfly Operation) (a,b) = (a+(b>>w), a-(b>>w)) mod 2^n+1 a=[coef[0],ms->lenw+1],...
static void lmmp_fft_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
#define _FFT_TABLE_ENTRY4(n)
void lmmp_mul_fft_unbalance_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
static void * lmmp_fft_memstack_(fft_memstack *ms, mp_size_t size)
FFT内存栈的分配/释放接口
static void lmmp_fft_shl_coef_(fft_memstack *ms, mp_ptr *coef, mp_size_t shl)
对模 2^n+1 的系数执行左移操作
mp_ptr temp_coef_mersenne
static void lmmp_mul_fft_cache_(mp_ptr dst, mp_size_t hn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
static void lmmp_mul_fermat_recombine_(fft_memstack *ms, mp_ptr dst, mp_ptr *pfca, mp_size_t K, mp_size_t k, mp_size_t n, mp_size_t M, mp_size_t rn)
费马变换 模 B^n+1 乘法的结果合并
static void lmmp_ifft_b1_(fft_memstack *ms, mp_ptr *coef, mp_size_t dis, mp_size_t k, mp_size_t w, mp_size_t w0)
void lmmp_mul_fermat_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
费马数模乘法 [dst,rn+1]=[numa,na]*[numb,nb] mod B^rn+1
static void lmmp_mul_fft_cache_free_(fft_cache *GH)
static void lmmp_ifft_4_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
static void lmmp_fft_bfy_(fft_memstack *ms, mp_ptr *coef, mp_size_t wing, mp_size_t w)
FFT蝶形运算(Butterfly Operation) (a,b) = (a + b, (a-b) << w ) mod 2^n+1 a=[coef[0],ms->lenw+1],...
mp_size_t lmmp_fft_next_size_(mp_size_t n)
计算FFT运算所需的最小规整化长度(向上取整到2^k的倍数)
static mp_size_t lmmp_fft_best_k_(mp_size_t n)
查找对于 m>=n 的模 B^m+1 FFT运算的最优k值
static void lmmp_ifft_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
static void lmmp_fft_extract_coef_(mp_ptr dst, mp_srcptr numa, mp_size_t bitoffset, mp_size_t bits, mp_size_t lenw)
[dst,lenw+1] = [(bit*)numa+bitoffset, bits]
static const mp_size_t lmmp_fft_table_[][2]
fft_memstack msr_mersenne
void lmmp_mul_fft_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
FFT乘法运算 [dst,na+nb] = [numa,na] * [numb,nb]
static void lmmp_fft_b1_(fft_memstack *ms, mp_ptr *coef, mp_size_t dis, mp_size_t k, mp_size_t w, mp_size_t w0)
FFT递归函数
static void lmmp_fft_4_(fft_memstack *ms, mp_ptr *coef, mp_size_t k, mp_size_t w)
static void lmmp_mul_fermat_single_(mp_ptr dst, mp_size_t rn, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb, fft_cache *GH)
#define ALLOC_TYPE(n, type)