LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
mul_toom44.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../include/lammp/impl/toom_interp.h"
8
9/*
10Evaluate in: 0, +1, -1, +2, -2, 1/2, +inf
11
12 <-s--><--n--><--n--><--n-->
13 |-a3-|--a2--|--a1--|--a0--|
14 |b3-|--b2--|--b1--|--b0--|
15 <-t-><--n--><--n--><--n-->
16
17 v0 = a0 * b0 # A(0)*B(0)
18 v1 = ( a0+ a1+ a2+ a3)*( b0+ b1+ b2+ b3) # A(1)*B(1) ah <= 3 bh <= 3
19 vm1 = ( a0- a1+ a2- a3)*( b0- b1+ b2- b3) # A(-1)*B(-1) |ah| <= 1 |bh| <= 1
20 v2 = ( a0+2a1+4a2+8a3)*( b0+2b1+4b2+8b3) # A(2)*B(2) ah <= 14 bh <= 14
21 vm2 = ( a0-2a1+4a2-8a3)*( b0-2b1+4b2-8b3) # A(-2)*B(-2) |ah| <= 9 |bh| <= 9
22 vh = (8a0+4a1+2a2+ a3)*(8b0+4b1+2b2+ b3) # A(1/2)*B(1/2) ah <= 14 bh <= 14
23 vinf= a3 * b3 # A(inf)*B(inf)
24*/
25
26void lmmp_mul_toom44_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb) {
27 lmmp_param_assert(na >= nb);
28 lmmp_param_assert(4 * na <= 5 * nb);
29 mp_size_t n, s, t;
30 mp_limb_t cy;
31 enum toom7_flags flags;
32
33#define a0 numa
34#define a1 (numa + n)
35#define a2 (numa + 2 * n)
36#define a3 (numa + 3 * n)
37#define b0 numb
38#define b1 (numb + n)
39#define b2 (numb + 2 * n)
40#define b3 (numb + 3 * n)
41
42 lmmp_debug_assert(na >= nb);
43
44 n = (na + 3) >> 2;
46 mp_ptr restrict scratch = SALLOC_TYPE(8 * n + 8, mp_limb_t);
47
48 s = na - 3 * n;
49 t = nb - 3 * n;
50
51 lmmp_debug_assert(0 < s && s <= n);
52 lmmp_debug_assert(0 < t && t <= n);
53 lmmp_debug_assert(s >= t);
54
55 /* NOTE: The multiplications to v2, vm2, vh and vm1 overwrites the
56 * following limb, so these must be computed in order, and we need a
57 * one limb gap to tp. */
58#define v0 dst /* 2n */
59#define v1 (dst + 2 * n) /* 2n+1 */
60#define vinf (dst + 6 * n) /* s+t */
61#define v2 scratch /* 2n+1 */
62#define vm2 (scratch + 2 * n + 1) /* 2n+1 */
63#define vh (scratch + 4 * n + 2) /* 2n+1 */
64#define vm1 (scratch + 6 * n + 3) /* 2n+1 */
65#define tp (scratch + 8 * n + 5)
66
67 /* apx and bpx must not overlap with v1 */
68#define apx dst /* n+1 */
69#define amx (dst + n + 1) /* n+1 */
70#define bmx (dst + 2 * n + 2) /* n+1 */
71#define bpx (dst + 4 * n + 2) /* n+1 */
72
73 /* Total scratch need: 8*n + 5 + scratch for recursive calls. This
74 gives roughly 32 n/3 + log term. */
75
76 /* Compute apx = a0 + 2 a1 + 4 a2 + 8 a3 and amx = a0 - 2 a1 + 4 a2 - 8 a3. */
77 flags = (enum toom7_flags)(toom7_w1_neg & lmmp_toom_eval_dgr3_pm2_(apx, amx, numa, n, s, tp));
78
79 /* Compute bpx = b0 + 2 b1 + 4 b2 + 8 b3 and bmx = b0 - 2 b1 + 4 b2 - 8 b3. */
80 flags = (enum toom7_flags)(flags ^ (toom7_w1_neg & lmmp_toom_eval_dgr3_pm2_(bpx, bmx, numb, n, t, tp)));
81
82 lmmp_mul_n_(v2, apx, bpx, n + 1); /* v2, 2n+1 limbs */
83 lmmp_mul_n_(vm2, amx, bmx, n + 1); /* vm2, 2n+1 limbs */
84
85 /* Compute apx = 8 a0 + 4 a1 + 2 a2 + a3 = (((2*a0 + a1) * 2 + a2) * 2 + a3 */
86
87 cy = lmmp_addshl1_n_(apx, a1, a0, n);
88 cy = 2 * cy + lmmp_addshl1_n_(apx, a2, apx, n);
89 if (s < n) {
90 mp_limb_t cy2;
91 cy2 = lmmp_addshl1_n_(apx, a3, apx, s);
92 apx[n] = 2 * cy + lmmp_shl_(apx + s, apx + s, n - s, 1);
93 lmmp_inc_1(apx + s, cy2);
94 } else
95 apx[n] = 2 * cy + lmmp_addshl1_n_(apx, a3, apx, n);
96
97
98 /* Compute bpx = 8 b0 + 4 b1 + 2 b2 + b3 = (((2*b0 + b1) * 2 + b2) * 2 + b3 */
99
100 cy = lmmp_addshl1_n_(bpx, b1, b0, n);
101 cy = 2 * cy + lmmp_addshl1_n_(bpx, b2, bpx, n);
102 if (t < n) {
103 mp_limb_t cy2;
104 cy2 = lmmp_addshl1_n_(bpx, b3, bpx, t);
105 bpx[n] = 2 * cy + lmmp_shl_(bpx + t, bpx + t, n - t, 1);
106 lmmp_inc_1(bpx + t, cy2);
107 } else
108 bpx[n] = 2 * cy + lmmp_addshl1_n_(bpx, b3, bpx, n);
109
110 lmmp_debug_assert(apx[n] < 15);
111 lmmp_debug_assert(bpx[n] < 15);
112
113 lmmp_mul_n_(vh, apx, bpx, n + 1); /* vh, 2n+1 limbs */
114
115 /* Compute apx = a0 + a1 + a2 + a3 and amx = a0 - a1 + a2 - a3. */
116 flags = (enum toom7_flags)(flags | (toom7_w3_neg & lmmp_toom_eval_dgr3_pm1_(apx, amx, numa, n, s, tp)));
117
118 /* Compute bpx = b0 + b1 + b2 + b3 and bmx = b0 - b1 + b2 - b3. */
119 flags = (enum toom7_flags)(flags ^ (toom7_w3_neg & lmmp_toom_eval_dgr3_pm1_(bpx, bmx, numb, n, t, tp)));
120
121 lmmp_mul_n_(vm1, amx, bmx, n + 1); /* vm1, 2n+1 limbs */
122 /* Clobbers amx, bmx. */
123 lmmp_mul_n_(v1, apx, bpx, n + 1); /* v1, 2n+1 limbs */
124
125 lmmp_mul_n_(v0, a0, b0, n);
126 if (s > t)
127 lmmp_mul_(vinf, a3, s, b3, t);
128 else
129 lmmp_mul_n_(vinf, a3, b3, s);
130
131 lmmp_toom_interp7_(dst, n, flags, vm2, vm1, v2, vh, s + t, tp);
132
134}
#define scratch
mp_limb_t * mp_ptr
Definition lmmp.h:215
uint64_t mp_size_t
Definition lmmp.h:212
#define lmmp_debug_assert(x)
Definition lmmp.h:387
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_param_assert(x)
Definition lmmp.h:398
void lmmp_mul_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
不等长大数乘法操作 [dst,na+nb] = [numa,na] * [numb,nb]
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
等长大数乘法操作 [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_limb_t lmmp_addshl1_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
加法结合左移1位操作 [dst,n] = [numa,n] + ([numb,n] << 1)
Definition shl.c:56
mp_limb_t lmmp_shl_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_size_t shl)
大数左移操作 [dst,na] = [numa,na]<<shl,dst的低shl位填充0
Definition shl.c:9
#define lmmp_inc_1(p, inc)
大数加指定值宏(预期无进位)
Definition lmmpn.h:958
#define b0
#define v0
#define a3
void lmmp_mul_toom44_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
Definition mul_toom44.c:26
#define b1
#define v2
#define vm1
#define apx
#define vh
#define a2
#define a0
#define tp
#define a1
#define bmx
#define b3
#define b2
#define vinf
#define bpx
#define amx
#define v1
#define vm2
#define SALLOC_TYPE(n, type)
Definition tmp_alloc.h:87
#define TEMP_S_DECL
Definition tmp_alloc.h:76
#define TEMP_S_FREE
Definition tmp_alloc.h:105
toom7_flags
Definition toom_interp.h:27
@ toom7_w1_neg
Definition toom_interp.h:27
@ toom7_w3_neg
Definition toom_interp.h:27
void lmmp_toom_interp7_(mp_ptr dst, mp_size_t n, enum toom7_flags flags, mp_ptr w1, mp_ptr w3, mp_ptr w4, mp_ptr w5, mp_size_t w6n, mp_ptr tp)
Toom插值计算(7点插值):用于Toom-44、Toom-53、Toom-62 乘法算法
int lmmp_toom_eval_dgr3_pm2_(mp_ptr xp2, mp_ptr xm2, mp_srcptr xp, mp_size_t n, mp_size_t x3n, mp_ptr tp)
Toom-3 专用:3次多项式在 x = +2 和 x = -2 处求值 计算 P(+2) 和 P(-2),其中 P(x) 是一个3次多项式(4段系数)。
int lmmp_toom_eval_dgr3_pm1_(mp_ptr xp1, mp_ptr xm1, mp_srcptr xp, mp_size_t n, mp_size_t x3n, mp_ptr tp)
Toom-3 专用:3次多项式在 x = +1 和 x = -1 处求值 计算 P(+1) 和 P(-1),其中 P(x) 是一个3次多项式(4段系数)。