LAMMP 4.1.0
Lamina High-Precision Arithmetic Library
载入中...
搜索中...
未找到
mat22_mul.c
浏览该文件的文档.
1/*
2 * LAMMP - Copyright (C) 2025-2026 HJimmyK(Jericho Knox)
3 * This file is part of lammp, under the GNU LGPL v2 license.
4 * See LICENSE in the project root for the full license text.
5 */
6
7#include "../../../include/lammp/impl/signed.h"
8#include "../../../include/lammp/impl/tmp_alloc.h"
9#include "../../../include/lammp/lmmpn.h"
10#include "../../../include/lammp/impl/mparam.h"
11#include "../../../include/lammp/matrix.h"
12
14 lmmp_mat22_t* dst,
15 const lmmp_mat22_t* matA,
16 const lmmp_mat22_t* matB,
17 mp_size_t* tn,
18 mp_size_t* maxa
19) {
20 lmmp_param_assert(matA!= NULL && matB!= NULL && dst!= NULL);
21 lmmp_param_assert(tn != NULL);
22 if (matA == matB) {
23 mp_ssize_t A00 = LMMP_ABS(matA->n00);
24 mp_ssize_t A01 = LMMP_ABS(matA->n01);
25 mp_ssize_t A10 = LMMP_ABS(matA->n10);
26 mp_ssize_t A11 = LMMP_ABS(matA->n11);
29 dst->n00 = LMMP_MAX((A00 + A00), (A01 + A10));
30 dst->n01 = LMMP_MAX((A00 + A01), (A01 + A11));
31 dst->n10 = LMMP_MAX((A10 + A00), (A11 + A10));
32 dst->n11 = LMMP_MAX((A10 + A01), (A11 + A11));
33 *tn = LMMP_MAX(LMMP_MAX(LMMP_MAX(dst->n00, dst->n01), dst->n10), dst->n11) + 1;
34 ++(dst->n00);
35 ++(dst->n01);
36 ++(dst->n10);
37 ++(dst->n11);
38 return 0;
39 } else {
40 *maxa = LMMP_MAX(LMMP_MAX(LMMP_MAX(A00, A01), A10), A11) + 1;
41 *tn = (*maxa << 1) + 1;
42 dst->n00 = *tn;
43 dst->n01 = *tn;
44 dst->n10 = *tn;
45 dst->n11 = *tn;
46 return 1;
47 }
48 } else {
49 mp_ssize_t A00 = LMMP_ABS(matA->n00);
50 mp_ssize_t A01 = LMMP_ABS(matA->n01);
51 mp_ssize_t A10 = LMMP_ABS(matA->n10);
52 mp_ssize_t A11 = LMMP_ABS(matA->n11);
53 mp_ssize_t B00 = LMMP_ABS(matB->n00);
54 mp_ssize_t B01 = LMMP_ABS(matB->n01);
55 mp_ssize_t B10 = LMMP_ABS(matB->n10);
56 mp_ssize_t B11 = LMMP_ABS(matB->n11);
61 dst->n00 = LMMP_MAX((A00 + B00), (A01 + B10));
62 dst->n01 = LMMP_MAX((A00 + B01), (A01 + B11));
63 dst->n10 = LMMP_MAX((A10 + B00), (A11 + B10));
64 dst->n11 = LMMP_MAX((A10 + B01), (A11 + B11));
65 *tn = LMMP_MAX(LMMP_MAX(LMMP_MAX(dst->n00, dst->n01), dst->n10), dst->n11);
66 ++(dst->n00);
67 ++(dst->n01);
68 ++(dst->n10);
69 ++(dst->n11);
70 return 0;
71 } else {
72 *maxa = LMMP_MAX(LMMP_MAX(LMMP_MAX(A00, A01), A10), A11) + 1;
73 *tn = *maxa + LMMP_MAX(LMMP_MAX(LMMP_MAX(B00, B01), B10), B11) + 1;
74 dst->n00 = *tn;
75 dst->n01 = *tn;
76 dst->n10 = *tn;
77 dst->n11 = *tn;
78 return 1;
79 }
80 }
81}
82
84 lmmp_mat22_t* dst,
85 const lmmp_mat22_t* matA,
86 const lmmp_mat22_t* matB,
87 mp_ptr tp,
88 mp_size_t tn
89) {
90 lmmp_param_assert(matA != NULL && matB != NULL && dst != NULL);
91 lmmp_param_assert(tn > 0);
92 if (matA == matB) {
93 lmmp_mat22_sqr_basecase_(dst, matA, tp, tn);
94 return;
95 }
97 if (tp == NULL)
98 tp = TALLOC_TYPE(tn * 2, mp_limb_t);
99#define p1 tp
100#define p2 tp + tn
101 mp_ssize_t pn1, pn2;
102 pn1 = lmmp_mul_signed_(p1, matA->a00, matA->n00, matB->a00, matB->n00);
103 pn2 = lmmp_mul_signed_(p2, matA->a01, matA->n01, matB->a10, matB->n10);
104 dst->n00 = lmmp_add_signed_(dst->a00, p1, pn1, p2, pn2);
105 pn1 = lmmp_mul_signed_(p1, matA->a00, matA->n00, matB->a01, matB->n01);
106 pn2 = lmmp_mul_signed_(p2, matA->a01, matA->n01, matB->a11, matB->n11);
107 dst->n01 = lmmp_add_signed_(dst->a01, p1, pn1, p2, pn2);
108 pn1 = lmmp_mul_signed_(p1, matA->a10, matA->n10, matB->a00, matB->n00);
109 pn2 = lmmp_mul_signed_(p2, matA->a11, matA->n11, matB->a10, matB->n10);
110 dst->n10 = lmmp_add_signed_(dst->a10, p1, pn1, p2, pn2);
111 pn1 = lmmp_mul_signed_(p1, matA->a10, matA->n10, matB->a01, matB->n01);
112 pn2 = lmmp_mul_signed_(p2, matA->a11, matA->n11, matB->a11, matB->n11);
113 dst->n11 = lmmp_add_signed_(dst->a11, p1, pn1, p2, pn2);
114#undef p1
115#undef p2
116 TEMP_FREE;
117}
118
120 lmmp_mat22_t* dst,
121 const lmmp_mat22_t* matA,
122 mp_ptr tp,
123 mp_size_t tn
124) {
125 TEMP_DECL;
126 if (tp == NULL)
127 tp = TALLOC_TYPE(tn * 2, mp_limb_t);
128#define p1 tp
129#define p2 tp + tn
130 mp_ssize_t pn1, pn2;
131 pn1 = lmmp_sqr_signed_(p1, matA->a00, matA->n00);
132 pn2 = lmmp_mul_signed_(p2, matA->a01, matA->n01, matA->a10, matA->n10);
133 dst->n00 = lmmp_add_signed_(dst->a00, p1, pn1, p2, pn2);
134 pn1 = lmmp_mul_signed_(p1, matA->a00, matA->n00, matA->a01, matA->n01);
135 pn2 = lmmp_mul_signed_(p2, matA->a01, matA->n01, matA->a11, matA->n11);
136 dst->n01 = lmmp_add_signed_(dst->a01, p1, pn1, p2, pn2);
137 pn1 = lmmp_mul_signed_(p1, matA->a10, matA->n10, matA->a00, matA->n00);
138 pn2 = lmmp_mul_signed_(p2, matA->a11, matA->n11, matA->a10, matA->n10);
139 dst->n10 = lmmp_add_signed_(dst->a10, p1, pn1, p2, pn2);
140 pn1 = lmmp_mul_signed_(p1, matA->a10, matA->n10, matA->a01, matA->n01);
141 pn2 = lmmp_sqr_signed_(p2, matA->a11, matA->n11);
142 dst->n11 = lmmp_add_signed_(dst->a11, p1, pn1, p2, pn2);
143#undef p1
144#undef p2
145 TEMP_FREE;
146}
147
148/*
149 * Strassen 2x2 矩阵乘法的 Winograd 变体
150 *
151 * 输入矩阵:
152 * A = | A11 A12 |
153 * | A21 A22 |
154 * B = | B11 B12 |
155 * | B21 B22 |
156 *
157 * 输出矩阵 C = A * B:
158 * C = | C11 C12 |
159 * | C21 C22 |
160 *
161 *
162 * s1 = A22 + A12
163 * s2 = A22 - A21
164 * s3 = s2 + A12 = A22 - A21 + A12
165 * s4 = s3 - A11 = A22 - A21 + A12 - A11
166 *
167 * t1 = B22 + B12
168 * t2 = B22 - B21
169 * t3 = t2 + B12 = B22 - B21 + B12
170 * t4 = t3 - B11 = B22 - B21 + B12 - B11
171 *
172 * 7 个 Strassen 乘积项
173 * p1 = s1 * t1 = (A22 + A12 ) * (B22 + B12 )
174 * p2 = s2 * t2 = (A22 - A21 ) * (B22 - B21 )
175 * p3 = s3 * t3 = (A22 - A21 + A12) * (B22 - B21 + B12)
176 * p4 = A11 * B11
177 * p5 = A12 * B21
178 * p6 = s4 * B12
179 * p7 = A21 * t4
180 *
181 * U1 = p3 + p5
182 * U2 = p1 - U1
183 * U3 = U1 - p2
184 *
185 * result:
186 * C11 = p4 + p5
187 * C12 = U3 - p6
188 * C21 = U2 - p7
189 * C22 = p2 + U2
190 *
191 * 平方版本(A*A):此时 B = A,故 t_i = s_i(t4 即 s4);p1~p4 变为平方,p5~p7 仍为普通乘法(p6 = s4 * A12,p7 = A21 * s4),其余流程一致。
192 */
193
195 lmmp_mat22_t* dst,
196 const lmmp_mat22_t* matA,
197 const lmmp_mat22_t* matB,
198 mp_ptr tp,
199 mp_size_t tn,
200 mp_size_t maxa
201) {
202 lmmp_param_assert(matA != NULL && matB != NULL && dst != NULL);
203 lmmp_param_assert(tn > 0 && maxa > 0);
204 if (matA == matB) {
205 lmmp_mat22_sqr_strassen_(dst, matA, tp, tn);
206 return;
207 }
209 ++tn;
210 if (tp == NULL)
211 tp = BALLOC_TYPE(tn * 7, mp_limb_t);
212
213#define A11 (matA->a00)
214#define A12 (matA->a01)
215#define A21 (matA->a10)
216#define A22 (matA->a11)
217#define B11 (matB->a00)
218#define B12 (matB->a01)
219#define B21 (matB->a10)
220#define B22 (matB->a11)
221#define A11n (matA->n00)
222#define A12n (matA->n01)
223#define A21n (matA->n10)
224#define A22n (matA->n11)
225#define B11n (matB->n00)
226#define B12n (matB->n01)
227#define B21n (matB->n10)
228#define B22n (matB->n11)
229
230#define s1 (dst->a00)
231#define s2 (dst->a01)
232#define s3 (dst->a10)
233#define s4 (dst->a11)
234#define t1 (dst->a00 + maxa)
235#define t2 (dst->a01 + maxa)
236#define t3 (dst->a10 + maxa)
237#define t4 (dst->a11 + maxa)
238#define p1 (tp)
239#define p2 (tp + tn)
240#define p3 (tp + 2 * tn)
241#define p4 (tp + 3 * tn)
242#define p5 (tp + 4 * tn)
243#define p6 (tp + 5 * tn)
244#define p7 (tp + 6 * tn)
245 mp_ssize_t n1, n2, n3, n4, n5, n6, n7, n8;
247 n2 = lmmp_add_signed_(s2, A22, A22n, A21, -A21n);
248 n3 = lmmp_add_signed_(s3, s2, n2, A12, A12n);
249 n4 = lmmp_add_signed_(s4, s3, n3, A11, -A11n);
251 n6 = lmmp_add_signed_(t2, B22, B22n, B21, -B21n);
252 n7 = lmmp_add_signed_(t3, t2, n6, B12, B12n);
253 n8 = lmmp_add_signed_(t4, t3, n7, B11, -B11n);
254
255 n1 = lmmp_mul_signed_(p1, s1, n1, t1, n5);
256 n5 = lmmp_mul_signed_(p2, s2, n2, t2, n6);
257 n2 = lmmp_mul_signed_(p3, s3, n3, t3, n7);
260 n3 = lmmp_mul_signed_(p6, s4, n4, B12, B12n);
261 n4 = lmmp_mul_signed_(p7, A21, A21n, t4, n8);
262
263#undef s1
264#undef s2
265#undef s3
266#undef s4
267#undef t1
268#undef t2
269#undef t3
270#undef t4
271
272#define p1n n1
273#define p2n n5
274#define p3n n2
275#define p4n n7
276#define p5n n6
277#define p6n n3
278#define p7n n4
279
280#undef A11
281#undef A12
282#undef A21
283#undef A22
284#undef B11
285#undef B12
286#undef B21
287#undef B22
288#undef A11n
289#undef A12n
290#undef A21n
291#undef A22n
292#undef B11n
293#undef B12n
294#undef B21n
295#undef B22n
296
297#define C11 (dst->a00)
298#define C12 (dst->a01)
299#define C21 (dst->a10)
300#define C22 (dst->a11)
301#define C11n (dst->n00)
302#define C12n (dst->n01)
303#define C21n (dst->n10)
304#define C22n (dst->n11)
305
307#define U1 p5 // U1 = p3 + p5
308#define U2 p1 // U2 = p1 - U1
309#define U3 U1 // U3 = U1 - p2
310#define U1n p5n
311#define U2n p1n
312#define U3n n8
316
321
322#undef C11
323#undef C12
324#undef C21
325#undef C22
326#undef C11n
327#undef C12n
328#undef C21n
329#undef C22n
330#undef U1
331#undef U2
332#undef U3
333#undef U1n
334#undef U2n
335#undef U3n
336
337#undef p1
338#undef p2
339#undef p3
340#undef p4
341#undef p5
342#undef p6
343#undef p7
344}
345
347 lmmp_param_assert(mat != NULL && dst != NULL);
349 ++tn;
350 if (tp == NULL)
351 tp = BALLOC_TYPE(tn * 7, mp_limb_t);
352
353#define A11 (mat->a00)
354#define A12 (mat->a01)
355#define A21 (mat->a10)
356#define A22 (mat->a11)
357#define A11n (mat->n00)
358#define A12n (mat->n01)
359#define A21n (mat->n10)
360#define A22n (mat->n11)
361
362#define s1 (dst->a00)
363#define s2 (dst->a01)
364#define s3 (dst->a10)
365#define s4 (dst->a11)
366#define p1 (tp)
367#define p2 (tp + tn)
368#define p3 (tp + 2 * tn)
369#define p4 (tp + 3 * tn)
370#define p5 (tp + 4 * tn)
371#define p6 (tp + 5 * tn)
372#define p7 (tp + 6 * tn)
373 mp_ssize_t n1, n2, n3, n4, n5, n6, n7, n8;
375 n2 = lmmp_add_signed_(s2, A22, A22n, A21, -A21n);
376 n3 = lmmp_add_signed_(s3, s2, n2, A12, A12n);
377 n4 = lmmp_add_signed_(s4, s3, n3, A11, -A11n);
378
379 n1 = lmmp_sqr_signed_(p1, s1, n1);
380 n5 = lmmp_sqr_signed_(p2, s2, n2);
381 n2 = lmmp_sqr_signed_(p3, s3, n3);
382 n7 = lmmp_sqr_signed_(p4, A11, A11n);
384 n3 = lmmp_mul_signed_(p6, s4, n4, A12, A12n);
385 n4 = lmmp_mul_signed_(p7, A21, A21n, s4, n4);
386
387#undef s1
388#undef s2
389#undef s3
390#undef s4
391
392#define p1n n1
393#define p2n n5
394#define p3n n2
395#define p4n n7
396#define p5n n6
397#define p6n n3
398#define p7n n4
399
400#undef A11
401#undef A12
402#undef A21
403#undef A22
404#undef A11n
405#undef A12n
406#undef A21n
407#undef A22n
408
409#define C11 (dst->a00)
410#define C12 (dst->a01)
411#define C21 (dst->a10)
412#define C22 (dst->a11)
413#define C11n (dst->n00)
414#define C12n (dst->n01)
415#define C21n (dst->n10)
416#define C22n (dst->n11)
417
419#define U1 p5 // U1 = p3 + p5
420#define U2 p1 // U2 = p1 - U1
421#define U3 U1 // U3 = U1 - p2
422#define U1n p5n
423#define U2n p1n
424#define U3n n8
428
433
434#undef C11
435#undef C12
436#undef C21
437#undef C22
438#undef C11n
439#undef C12n
440#undef C21n
441#undef C22n
442#undef U1
443#undef U2
444#undef U3
445#undef U1n
446#undef U2n
447#undef U3n
448
449#undef p1
450#undef p2
451#undef p3
452#undef p4
453#undef p5
454#undef p6
455#undef p7
456}
mp_limb_t * mp_ptr
Definition lmmp.h:215
uint64_t mp_size_t
Definition lmmp.h:212
int64_t mp_ssize_t
Definition lmmp.h:214
#define LMMP_MAX(h, i)
Definition lmmp.h:350
uint64_t mp_limb_t
Definition lmmp.h:211
#define LMMP_ABS(x)
Definition lmmp.h:346
#define lmmp_param_assert(x)
Definition lmmp.h:398
#define s1
#define C11n
#define s4
#define C22n
#define U2n
#define s3
#define p6n
#define s2
#define C12n
#define t4
#define p6
#define A22n
void lmmp_mat22_mul_strassen_(lmmp_mat22_t *dst, const lmmp_mat22_t *matA, const lmmp_mat22_t *matB, mp_ptr tp, mp_size_t tn, mp_size_t maxa)
计算(稠密)2x2矩阵和(稠密)2x2矩阵的乘积(STRASSEN算法)
Definition mat22_mul.c:194
#define B12n
int lmmp_mat22_mul_size_(lmmp_mat22_t *dst, const lmmp_mat22_t *matA, const lmmp_mat22_t *matB, mp_size_t *tn, mp_size_t *maxa)
计算2x2矩阵和2x2矩阵的乘积需要分配的内存
Definition mat22_mul.c:13
void lmmp_mat22_mul_basecase_(lmmp_mat22_t *dst, const lmmp_mat22_t *matA, const lmmp_mat22_t *matB, mp_ptr tp, mp_size_t tn)
计算2x2矩阵和2x2矩阵的乘积
Definition mat22_mul.c:83
#define B21
#define t1
#define B21n
#define A22
#define B12
#define t3
void lmmp_mat22_sqr_strassen_(lmmp_mat22_t *dst, const lmmp_mat22_t *mat, mp_ptr tp, mp_size_t tn)
计算(稠密)2x2矩阵平方(STRASSEN算法)
Definition mat22_mul.c:346
#define U3n
#define A11n
#define A12
#define C12
#define U3
#define p1n
#define B22n
#define p2
#define A21n
#define p4n
#define A21
#define t2
#define A11
#define U1
#define C21n
#define C22
void lmmp_mat22_sqr_basecase_(lmmp_mat22_t *dst, const lmmp_mat22_t *matA, mp_ptr tp, mp_size_t tn)
计算2x2矩阵平方
Definition mat22_mul.c:119
#define A12n
#define B22
#define p3
#define B11
#define p3n
#define C21
#define C11
#define p5n
#define p1
#define p7
#define U1n
#define p7n
#define U2
#define p2n
#define p5
#define p4
#define B11n
mp_ptr a01
Definition matrix.h:60
mp_ssize_t n10
Definition matrix.h:65
mp_ptr a11
Definition matrix.h:62
mp_ssize_t n11
Definition matrix.h:66
mp_ssize_t n01
Definition matrix.h:64
mp_ssize_t n00
Definition matrix.h:63
mp_ptr a00
Definition matrix.h:59
mp_ptr a10
Definition matrix.h:61
#define MAT22_SQR_STRASSEN_THRESHOLD
Definition mparam.h:101
#define MAT22_MUL_STRASSEN_THRESHOLD
Definition mparam.h:98
#define tp
static mp_ssize_t lmmp_sqr_signed_(mp_ptr dst, mp_srcptr numa, mp_ssize_t na)
计算带符号数的平方
Definition signed.h:171
static mp_ssize_t lmmp_add_signed_(mp_ptr dst, mp_srcptr numa, mp_ssize_t na, mp_srcptr numb, mp_ssize_t nb)
计算带符号数的加法
Definition signed.h:38
static mp_ssize_t lmmp_mul_signed_(mp_ptr dst, mp_srcptr numa, mp_ssize_t na, mp_srcptr numb, mp_ssize_t nb)
计算带符号数的乘法
Definition signed.h:146
#define TEMP_DECL
Definition tmp_alloc.h:72
#define TEMP_FREE
Definition tmp_alloc.h:93
#define TALLOC_TYPE(n, type)
Definition tmp_alloc.h:91
#define TEMP_B_DECL
Definition tmp_alloc.h:75
#define BALLOC_TYPE(n, type)
Definition tmp_alloc.h:89
#define TEMP_B_FREE
Definition tmp_alloc.h:100