LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
Loading...
Searching...
No Matches
mul_toom42.c
Go to the documentation of this file.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../include/lammp/lmmpn.h"
8#include "../../include/lammp/impl/toom_interp.h"
9
/*
Toom-4.2 evaluation points: -1, 0, +1, +2, +inf

  <-s-><--n--><--n--><--n-->
  |a3-|---a2-|---a1-|---a0-|     A = numa, na = 3n + s limbs
                 |-b1-|---b0-|   B = numb, nb = n + t limbs
                 <-t--><--n-->

v0  = a0 * b0                       # A(0)*B(0)
v1  = (a0+ a1+ a2+ a3)*(b0+ b1)     # A(1)*B(1)    ah <= 3    bh <= 1
vm1 = (a0- a1+ a2- a3)*(b0- b1)     # A(-1)*B(-1)  |ah| <= 1  bh = 0
v2  = (a0+2a1+4a2+8a3)*(b0+2b1)     # A(2)*B(2)    ah <= 14   bh <= 2
vinf= a3 * b1                       # A(inf)*B(inf)

ah / bh: bound on the extra (index n) limb of the evaluated A / B operand;
bh = 0 means |B(-1)| fits in n limbs.
*/
24
25void lmmp_mul_toom42_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb) {
26 lmmp_param_assert(nb >= 20);
27 lmmp_param_assert(na <= 3 * nb);
28 lmmp_param_assert(5 * na >= 9 * nb);
30 mp_size_t n = na >= 2 * nb ? (na + 3) >> 2 : (nb + 1) >> 1, s = na - 3 * n, t = nb - n;
31 int vm1_neg;
32 mp_limb_t cy, vinf0, am1h;
33 mp_limb_t* restrict tp = SALLOC_TYPE(4 * n + 4, mp_limb_t);
34
35#define a0 numa
36#define a1 (numa + n)
37#define a2 (numa + 2 * n)
38#define a3 (numa + 3 * n)
39#define b0 numb
40#define b1 (numb + n)
41
42#define v0 dst //[dst,2*n]
43#define v1 (dst + 2 * n) //[dst+2*n,2*n+1]
44#define vinf (dst + 4 * n) //[dst+4*n,s+t]
45#define vm1 tp //[tp,2*n+1]
46#define v2 (tp + 2 * n + 2) //[tp+2*n+2,2*n+1]
47
48#define bm1 dst //[dst,n]
49#define am1 (dst + n) //[dst+n,n+1]
50#define ap1 tp //[tp,n+1]
51#define bp1 (tp + n + 1) //[tp+n+1,n+1]
52#define ap2 ap1 // same space
53#define bp2 bp1 // same space
54#define a13 bp1 // temporary use
55
56 // ap1,am1
57 ap1[n] = lmmp_add_n_(ap1, a0, a2, n);
58 a13[n] = lmmp_add_(a13, a1, n, a3, s);
59 vm1_neg = lmmp_cmp_(ap1, a13, n + 1) < 0;
60 if (vm1_neg)
61 lmmp_add_n_sub_n_(ap1, am1, a13, ap1, n + 1);
62 else
63 lmmp_add_n_sub_n_(ap1, am1, ap1, a13, n + 1);
64 am1h = am1[n]; // overlap with v1
65
66 // bp1,bm1
67 if (t == n) {
68 if (lmmp_cmp_(b0, b1, n) < 0) {
69 bp1[n] = lmmp_add_n_sub_n_(bp1, bm1, b1, b0, n) >> 1;
70 vm1_neg ^= 1;
71 } else {
72 bp1[n] = lmmp_add_n_sub_n_(bp1, bm1, b0, b1, n) >> 1;
73 }
74 } else {
75 if (lmmp_zero_q_(b0 + t, n - t) && lmmp_cmp_(b0, b1, t) < 0) {
76 cy = lmmp_add_n_sub_n_(bp1, bm1, b1, b0, t);
77 lmmp_zero(bm1 + t, n - t);
78 vm1_neg ^= 1;
79 } else {
80 cy = lmmp_add_n_sub_n_(bp1, bm1, b0, b1, t);
81 lmmp_sub_1_(bm1 + t, b0 + t, n - t, cy & 1);
82 }
83 bp1[n] = lmmp_add_1_(bp1 + t, b0 + t, n - t, cy >> 1);
84 }
85
86 // vinf=a3*b1
87 if (s > t)
88 lmmp_mul_(vinf, a3, s, b1, t);
89 else
90 lmmp_mul_(vinf, b1, t, a3, s);
91 vinf0 = vinf[0]; // overlap with v1
92 cy = vinf[1]; // overlap with v1
93
94 // v1=ap1*bp1
95 lmmp_mul_n_(v1, ap1, bp1, n + 1);
96 vinf[1] = cy; // restore, since v1[2*n+1]==0.
97
98 // ap2
99 cy = lmmp_addshl1_n_(ap2, a2, a3, s);
100 if (s != n)
101 cy = lmmp_add_1_(ap2 + s, a2 + s, n - s, cy);
102 cy = 2 * cy + lmmp_addshl1_n_(ap2, a1, ap2, n);
103 cy = 2 * cy + lmmp_addshl1_n_(ap2, a0, ap2, n);
104 ap2[n] = cy;
105
106 // bp2=bp1+b1
107 lmmp_add_(bp2, bp1, n + 1, b1, t);
108
109 // v2=ap2*bp2
110 lmmp_mul_n_(v2, ap2, bp2, n + 1);
111
112 // vm1=am1*bm1
113 lmmp_mul_n_(vm1, am1, bm1, n);
114 if (am1h)
115 vm1[2 * n] = lmmp_add_n_(vm1 + n, vm1 + n, bm1, n);
116 else
117 vm1[2 * n] = 0;
118
119 // v0=a0*b0
120 lmmp_mul_n_(v0, a0, b0, n);
121
122 lmmp_toom_interp5_(dst, v2, vm1, n, s + t, vm1_neg, vinf0);
124#undef a0
125#undef a1
126#undef a2
127#undef a3
128#undef b0
129#undef b1
130
131#undef v0
132#undef v1
133#undef vinf
134#undef vm1
135#undef v2
136
137#undef bm1
138#undef am1
139#undef ap1
140#undef bp1
141#undef ap2
142#undef bp2
143#undef a13
144}
145
147 mp_ptr restrict dst,
148 mp_srcptr restrict numa,
149 mp_srcptr restrict numb,
150 mp_size_t n,
151 mp_size_t s,
152 mp_size_t t,
153 mp_ptr restrict _bp1,
154 mp_ptr restrict _bm1,
155 mp_ptr restrict tp
156) {
157 int vm1_neg, flag = 0;
158 mp_limb_t cy, vinf0, am1h;
159
160#define a0 numa
161#define a1 (numa + n)
162#define a2 (numa + 2 * n)
163#define a3 (numa + 3 * n)
164#define b0 numb
165#define b1 (numb + n)
166
167#define v0 dst //[dst,2*n]
168#define v1 (dst + 2 * n) //[dst+2*n,2*n+1]
169#define vinf (dst + 4 * n) //[dst+4*n,s+t]
170#define vm1 tp //[tp,2*n+1]
171#define v2 (tp + 2 * n + 2) //[tp+2*n+2,2*n+1]
172
173#define bm1 _bm1 //[dst,n]
174#define am1 (dst + n) //[dst+n,n+1]
175#define ap1 tp //[tp,n+1]
176#define bp1 _bp1 //[TH._bp1,n+1]
177#define ap2 ap1 // same space
178#define bp2 (tp + n + 1) //[tp+n+1,n+1]
179#define a13 (tp + n + 1) // same space
180
181 // ap1,am1
182 ap1[n] = lmmp_add_n_(ap1, a0, a2, n);
183 a13[n] = lmmp_add_(a13, a1, n, a3, s);
184 vm1_neg = lmmp_cmp_(ap1, a13, n + 1) < 0;
185 if (vm1_neg)
186 lmmp_add_n_sub_n_(ap1, am1, a13, ap1, n + 1);
187 else
188 lmmp_add_n_sub_n_(ap1, am1, ap1, a13, n + 1);
189 am1h = am1[n]; // overlap with v1
190
191 if (t == n) {
192 if (lmmp_cmp_(b0, b1, n) < 0) {
193 bp1[n] = lmmp_add_n_sub_n_(bp1, bm1, b1, b0, n) >> 1;
194 vm1_neg ^= 1;
195 flag = 1;
196 } else {
197 bp1[n] = lmmp_add_n_sub_n_(bp1, bm1, b0, b1, n) >> 1;
198 }
199 } else {
200 if (lmmp_zero_q_(b0 + t, n - t) && lmmp_cmp_(b0, b1, t) < 0) {
201 cy = lmmp_add_n_sub_n_(bp1, bm1, b1, b0, t);
202 lmmp_zero(bm1 + t, n - t);
203 vm1_neg ^= 1;
204 flag = 1;
205 } else {
206 cy = lmmp_add_n_sub_n_(bp1, bm1, b0, b1, t);
207 lmmp_sub_1_(bm1 + t, b0 + t, n - t, cy & 1);
208 }
209 bp1[n] = lmmp_add_1_(bp1 + t, b0 + t, n - t, cy >> 1);
210 }
211
212 // vinf=a3*b1
213 if (s > t)
214 lmmp_mul_(vinf, a3, s, b1, t);
215 else
216 lmmp_mul_(vinf, b1, t, a3, s);
217 vinf0 = vinf[0]; // overlap with v1
218 cy = vinf[1]; // overlap with v1
219
220 // v1=ap1*bp1
221 lmmp_mul_n_(v1, ap1, bp1, n + 1);
222 vinf[1] = cy; // restore, since v1[2*n+1]==0.
223
224 // ap2
225 cy = lmmp_addshl1_n_(ap2, a2, a3, s);
226 if (s != n)
227 cy = lmmp_add_1_(ap2 + s, a2 + s, n - s, cy);
228 cy = 2 * cy + lmmp_addshl1_n_(ap2, a1, ap2, n);
229 cy = 2 * cy + lmmp_addshl1_n_(ap2, a0, ap2, n);
230 ap2[n] = cy;
231
232 // bp2=bp1+b1
233 lmmp_add_(bp2, bp1, n + 1, b1, t);
234
235 // v2=ap2*bp2
236 lmmp_mul_n_(v2, ap2, bp2, n + 1);
237
238 // vm1=am1*bm1
239 lmmp_mul_n_(vm1, am1, bm1, n);
240 if (am1h)
241 vm1[2 * n] = lmmp_add_n_(vm1 + n, vm1 + n, bm1, n);
242 else
243 vm1[2 * n] = 0;
244
245 // v0=a0*b0
246 lmmp_mul_n_(v0, a0, b0, n);
247
248 lmmp_toom_interp5_(dst, v2, vm1, n, s + t, vm1_neg, vinf0);
249 return flag;
250#undef a0
251#undef a1
252#undef a2
253#undef a3
254#undef b0
255#undef b1
256
257#undef v0
258#undef v1
259#undef vinf
260#undef vm1
261#undef v2
262
263#undef bm1
264#undef am1
265#undef ap1
266#undef bp1
267#undef ap2
268#undef bp2
269#undef a13
270}
271
273 mp_ptr restrict dst,
274 mp_srcptr restrict numa,
275 mp_srcptr restrict numb,
276 mp_size_t n,
277 mp_size_t s,
278 mp_size_t t,
279 mp_srcptr restrict _bp1,
280 mp_srcptr restrict _bm1,
281 mp_ptr restrict tp,
282 int flag
283) {
284 int vm1_neg;
285 mp_limb_t cy, vinf0, am1h;
286
287#define a0 numa
288#define a1 (numa + n)
289#define a2 (numa + 2 * n)
290#define a3 (numa + 3 * n)
291#define b0 numb
292#define b1 (numb + n)
293
294#define v0 dst //[dst,2*n]
295#define v1 (dst + 2 * n) //[dst+2*n,2*n+1]
296#define vinf (dst + 4 * n) //[dst+4*n,s+t]
297#define vm1 tp //[tp,2*n+1]
298#define v2 (tp + 2 * n + 2) //[tp+2*n+2,2*n+1]
299
300#define bm1 _bm1 //[dst,n]
301#define am1 (dst + n) //[dst+n,n+1]
302#define ap1 tp //[tp,n+1]
303#define bp1 _bp1 //[TH._bp1,n+1]
304#define ap2 ap1 // same space
305#define bp2 (tp + n + 1) //[tp+n+1,n+1]
306#define a13 (tp + n + 1) // same space
307
308 // ap1,am1
309 ap1[n] = lmmp_add_n_(ap1, a0, a2, n);
310 a13[n] = lmmp_add_(a13, a1, n, a3, s);
311 vm1_neg = lmmp_cmp_(ap1, a13, n + 1) < 0;
312 if (vm1_neg)
313 lmmp_add_n_sub_n_(ap1, am1, a13, ap1, n + 1);
314 else
315 lmmp_add_n_sub_n_(ap1, am1, ap1, a13, n + 1);
316 am1h = am1[n]; // overlap with v1
317
318 if (flag)
319 vm1_neg ^= 1;
320
321 // vinf=a3*b1
322 if (s > t)
323 lmmp_mul_(vinf, a3, s, b1, t);
324 else
325 lmmp_mul_(vinf, b1, t, a3, s);
326 vinf0 = vinf[0]; // overlap with v1
327 cy = vinf[1]; // overlap with v1
328
329 // v1=ap1*bp1
330 lmmp_mul_n_(v1, ap1, bp1, n + 1);
331 vinf[1] = cy; // restore, since v1[2*n+1]==0.
332
333 // ap2
334 cy = lmmp_addshl1_n_(ap2, a2, a3, s);
335 if (s != n)
336 cy = lmmp_add_1_(ap2 + s, a2 + s, n - s, cy);
337 cy = 2 * cy + lmmp_addshl1_n_(ap2, a1, ap2, n);
338 cy = 2 * cy + lmmp_addshl1_n_(ap2, a0, ap2, n);
339 ap2[n] = cy;
340
341 // bp2=bp1+b1
342 lmmp_add_(bp2, bp1, n + 1, b1, t);
343
344 // v2=ap2*bp2
345 lmmp_mul_n_(v2, ap2, bp2, n + 1);
346
347 // vm1=am1*bm1
348 lmmp_mul_n_(vm1, am1, bm1, n);
349 if (am1h)
350 vm1[2 * n] = lmmp_add_n_(vm1 + n, vm1 + n, bm1, n);
351 else
352 vm1[2 * n] = 0;
353
354 // v0=a0*b0
355 lmmp_mul_n_(v0, a0, b0, n);
356
357 lmmp_toom_interp5_(dst, v2, vm1, n, s + t, vm1_neg, vinf0);
358}
359
360void lmmp_mul_toom42_unbalance_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb) {
361 lmmp_param_assert(na >= 3 * nb);
362 lmmp_param_assert(nb > 20);
363
365 mp_limb_t* restrict ws = SALLOC_TYPE(nb, mp_limb_t);
366 mp_size_t n = (2 * nb + 3) >> 2, s = 2 * nb - 3 * n, t = nb - n;
367 mp_ptr restrict tp = SALLOC_TYPE(4 * n + 4, mp_limb_t);
368 mp_ptr restrict _bp1 = SALLOC_TYPE(2 * n + 1, mp_limb_t);
369 mp_ptr restrict _bm1 = _bp1 + n + 1;
370 int flag = lmmp_mul_toom42_cache_init_(dst, numa, numb, n, s, t, _bp1, _bm1, tp);
371 dst += 2 * nb;
372 numa += 2 * nb;
373 na -= 2 * nb;
374 lmmp_copy(ws, dst, nb);
375 while (2 * na >= 5 * nb) {
376 lmmp_mul_toom42_cache_(dst, numa, numb, n, s, t, _bp1, _bm1, tp, flag);
377 if (lmmp_add_n_(dst, dst, ws, nb))
378 lmmp_inc(dst + nb);
379 dst += 2 * nb;
380 numa += 2 * nb;
381 na -= 2 * nb;
382 lmmp_copy(ws, dst, nb);
383 }
384 // 0.5 nb <= na < 2.5 nb
385 if (na >= nb)
386 lmmp_mul_(dst, numa, na, numb, nb);
387 else
388 lmmp_mul_(dst, numb, nb, numa, na);
389 if (lmmp_add_n_(dst, dst, ws, nb))
390 lmmp_inc(dst + nb);
392}
mp_limb_t * mp_ptr
Definition lmmp.h:215
#define lmmp_copy(dst, src, n)
Definition lmmp.h:364
#define lmmp_zero(dst, n)
Definition lmmp.h:366
uint64_t mp_size_t
Definition lmmp.h:212
const mp_limb_t * mp_srcptr
Definition lmmp.h:216
uint64_t mp_limb_t
Definition lmmp.h:211
#define lmmp_param_assert(x)
Definition lmmp.h:398
static mp_limb_t lmmp_add_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
Static inline multi-limb addition [dst,na]=[numa,na]+[numb,nb]
Definition lmmpn.h:1058
static int lmmp_cmp_(mp_srcptr numa, mp_srcptr numb, mp_size_t n)
Multi-limb comparison (inline)
Definition lmmpn.h:1004
static mp_limb_t lmmp_add_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
Static inline: add a single limb to a multi-limb number [dst,na]=[numa,na]+x
Definition lmmpn.h:1111
#define lmmp_inc(p)
Macro: increment a multi-limb number by 1 (no carry-out expected)
Definition lmmpn.h:946
void lmmp_mul_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_srcptr numb, mp_size_t nb)
Unbalanced (unequal-length) multiplication [dst,na+nb] = [numa,na] * [numb,nb]
void lmmp_mul_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
Balanced (equal-length) multiplication [dst,2*n] = [numa,n] * [numb,n]
Definition mul.c:99
mp_limb_t lmmp_addshl1_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
Addition combined with a 1-bit left shift [dst,n] = [numa,n] + ([numb,n] << 1)
Definition shl.c:56
mp_limb_t lmmp_add_n_sub_n_(mp_ptr dsta, mp_ptr dstb, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
Simultaneous n-limb addition and subtraction ([dsta,n],[dstb,n]) = ([numa,n]+[numb,n],[numa,n]-[numb,n])
Definition add_n_sub_n.c:10
static mp_limb_t lmmp_sub_1_(mp_ptr dst, mp_srcptr numa, mp_size_t na, mp_limb_t x)
Static inline: subtract a single limb from a multi-limb number [dst,na]=[numa,na]-x
Definition lmmpn.h:1122
mp_limb_t lmmp_add_n_(mp_ptr dst, mp_srcptr numa, mp_srcptr numb, mp_size_t n)
n-limb addition without carry-in [dst,n] = [numa,n] + [numb,n]
Definition add_n.c:71
static int lmmp_zero_q_(mp_srcptr p, mp_size_t n)
Multi-limb zero test (inline)
Definition lmmpn.h:1027
static void lmmp_mul_toom42_cache_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_size_t n, mp_size_t s, mp_size_t t, mp_srcptr restrict _bp1, mp_srcptr restrict _bm1, mp_ptr restrict tp, int flag)
Definition mul_toom42.c:272
#define ap2
#define b0
#define v0
#define a3
#define b1
#define am1
void lmmp_mul_toom42_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
Definition mul_toom42.c:25
static int lmmp_mul_toom42_cache_init_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_srcptr restrict numb, mp_size_t n, mp_size_t s, mp_size_t t, mp_ptr restrict _bp1, mp_ptr restrict _bm1, mp_ptr restrict tp)
Definition mul_toom42.c:146
#define ap1
#define v2
#define bp1
#define vm1
#define a13
#define bm1
#define bp2
void lmmp_mul_toom42_unbalance_(mp_ptr restrict dst, mp_srcptr restrict numa, mp_size_t na, mp_srcptr restrict numb, mp_size_t nb)
Definition mul_toom42.c:360
#define a2
#define a0
#define a1
#define vinf
#define v1
#define tp
#define SALLOC_TYPE(n, type)
Definition tmp_alloc.h:87
#define TEMP_S_DECL
Definition tmp_alloc.h:76
#define TEMP_S_FREE
Definition tmp_alloc.h:105
void lmmp_toom_interp5_(mp_ptr dst, mp_ptr v2, mp_ptr vm1, mp_size_t n, mp_size_t spt, int vm1_neg, mp_limb_t vinf0)
Toom interpolation (5-point), used by the Toom-33 and Toom-42 multiplication algorithms