changelog shortlog tags changeset browse all files revisions annotate raw

sage/matrix/strassen.pyx

changeset 2857: 8f618bb5ffc8
parent:0cada7268ad6
child:f28a998cee28
author: Robert Bradshaw <robertwb@math.washington.edu>
date: Wed Feb 07 02:03:37 2007 -0800 (21 months ago)
permissions: -rw-r--r--
description: re-wrote multi-modular class to consolidate (for better algorithms on large number of basis elements), workaround for bug on sage.math with matrix_window_modn_dense, finish implementing multi-modular mult, echelon, consolidate memory allocation in strassen, etc.
1"""
2Generic Asymptotically Fast Strassen Algorithms
3
4SAGE implements asymptotically fast echelon form and matrix multiplication algorithms.
5"""
6
7################################################################################
8# Copyright (C) 2005, 2006 William Stein <wstein@gmail.com>
9#
10# Distributed under the terms of the GNU General Public License (GPL), version 2.
11# The full text of the GPL is available at:
12#
13# http://www.gnu.org/licenses/
14################################################################################
15
16
17from matrix_window cimport MatrixWindow
18
19include "../ext/interrupt.pxi"
20
21
def strassen_window_multiply(C, A,B, cutoff):
    """
    Multiplies the submatrices specified by A and B, places result
    in C.  Assumes that A and B have compatible dimensions to be
    multiplied, and that C is the correct size to receive the
    product, and that they are all defined over the same ring.

    Uses strassen multiplication at high levels and then uses MatrixWindow
    methods at low levels.

    INPUT:
        C -- matrix window that receives the product
        A, B -- matrix windows to be multiplied
        cutoff -- integer; the recursion stops, and MatrixWindow
                  methods are used directly, as soon as any of the
                  relevant side lengths is <= cutoff

    EXAMPLES:
    The following matrix dimensions are chosen especially to exercise the
    eight possible parity combinations that could occur while subdividing
    the matrix in the strassen recursion.  The base case in both cases will
    be a (4x5) matrix times a (5x6) matrix.

        sage: A = MatrixSpace(Integers(2^65), 64, 83).random_element()
        sage: B = MatrixSpace(Integers(2^65), 83, 101).random_element()
        sage: A._multiply_classical(B) == A._multiply_strassen(B, 3)
        True

    AUTHOR: David Harvey
    """
    # Thin Python-callable wrapper around the cdef implementation.
    strassen_window_multiply_c(C, A, B, cutoff)
45
cdef strassen_window_multiply_c(MatrixWindow C, MatrixWindow A,
                                MatrixWindow B, Py_ssize_t cutoff):
    # Set C = A*B.
    #
    # Each of A, B and C is cut into four equally sized quadrants (any odd
    # final row/column is set aside), seven quadrant products P0..P6 are
    # formed -- recursively, or directly once a half-dimension is <= cutoff --
    # and combined into the four quadrants of C.  The set-aside row/column
    # contributions are then patched in at the end.

    # todo -- I'm not sure how to interpret "cutoff". Should it be...
    # (a) the minimum side length of the matrices (currently implemented below)
    # (b) the maximum side length of the matrices
    # (c) the total number of entries being multiplied
    # (d) something else entirely?

    cdef Py_ssize_t A_nrows, A_ncols, B_ncols
    A_nrows = A._nrows
    A_ncols = A._ncols        # this should also be the number of rows of B
    B_ncols = B._ncols

    if (A_nrows <= cutoff) or (A_ncols <= cutoff) or (B_ncols <= cutoff):
        # note: this code is only reached if the TOP level is already beneath
        # the cutoff. In a typical large multiplication, the base case is
        # handled directly (see below).
        C.set_to_prod(A, B)
        return

    # Construct windows for the four quadrants of each matrix.
    # Note that if the side lengths are odd we're ignoring the
    # final row/column for the moment.

    cdef Py_ssize_t A_sub_nrows, A_sub_ncols, B_sub_ncols
    A_sub_nrows = A_nrows >> 1
    A_sub_ncols = A_ncols >> 1     # this is also like B_sub_nrows
    B_sub_ncols = B_ncols >> 1

    cdef MatrixWindow A00, A01, A10, A11, B00, B01, B10, B11
    A00 = A.matrix_window(0,           0,           A_sub_nrows, A_sub_ncols)
    A01 = A.matrix_window(0,           A_sub_ncols, A_sub_nrows, A_sub_ncols)
    A10 = A.matrix_window(A_sub_nrows, 0,           A_sub_nrows, A_sub_ncols)
    A11 = A.matrix_window(A_sub_nrows, A_sub_ncols, A_sub_nrows, A_sub_ncols)
    B00 = B.matrix_window(0,           0,           A_sub_ncols, B_sub_ncols)
    B01 = B.matrix_window(0,           B_sub_ncols, A_sub_ncols, B_sub_ncols)
    B10 = B.matrix_window(A_sub_ncols, 0,           A_sub_ncols, B_sub_ncols)
    B11 = B.matrix_window(A_sub_ncols, B_sub_ncols, A_sub_ncols, B_sub_ncols)

    # Allocate temp space.
    # One window is allocated and then carved into S0..S3, T0..T3, Q0..Q2,
    # stacked vertically: 4 S blocks + 3 Q blocks of A_sub_nrows rows each and
    # 4 T blocks of A_sub_ncols rows each, hence 7*A_sub_nrows + 4*A_sub_ncols.

    cdef MatrixWindow S0, S1, S2, S3, T0, T1 ,T2, T3, Q0, Q1, Q2
    cdef MatrixWindow tmp
    cdef Py_ssize_t tmp_cols, start_row
    tmp_cols = A_sub_ncols
    if (tmp_cols < B_sub_ncols):
        tmp_cols = B_sub_ncols      # tmp_cols = max(A_sub_ncols, B_sub_ncols)
    tmp = A.new_empty_window(7*A_sub_nrows + 4*A_sub_ncols, tmp_cols)

    start_row = 0
    S0 = tmp.matrix_window(start_row, 0, A_sub_nrows, A_sub_ncols)
    start_row += A_sub_nrows
    S1 = tmp.matrix_window(start_row, 0, A_sub_nrows, A_sub_ncols)
    start_row += A_sub_nrows
    S2 = tmp.matrix_window(start_row, 0, A_sub_nrows, A_sub_ncols)
    start_row += A_sub_nrows
    S3 = tmp.matrix_window(start_row, 0, A_sub_nrows, A_sub_ncols)
    start_row += A_sub_nrows

    T0 = tmp.matrix_window(start_row, 0, A_sub_ncols, B_sub_ncols)
    start_row += A_sub_ncols
    T1 = tmp.matrix_window(start_row, 0, A_sub_ncols, B_sub_ncols)
    start_row += A_sub_ncols
    T2 = tmp.matrix_window(start_row, 0, A_sub_ncols, B_sub_ncols)
    start_row += A_sub_ncols
    T3 = tmp.matrix_window(start_row, 0, A_sub_ncols, B_sub_ncols)
    start_row += A_sub_ncols

    Q0 = tmp.matrix_window(start_row, 0, A_sub_nrows, B_sub_ncols)
    start_row += A_sub_nrows
    Q1 = tmp.matrix_window(start_row, 0, A_sub_nrows, B_sub_ncols)
    start_row += A_sub_nrows
    Q2 = tmp.matrix_window(start_row, 0, A_sub_nrows, B_sub_ncols)


    # Preparatory matrix additions/subtractions.

    # todo: we can probably save some memory in these
    # operations by reusing some of the buffers, if we interleave
    # these additions with the multiplications (below)

    # (I believe we can save on one S buffer and one T buffer)

    # S0 = A10 + A11,  T0 = B01 - B00
    # S1 = S0 - A00,   T1 = B11 - T0
    # S2 = A00 - A10,  T2 = B11 - B01
    # S3 = A01 - S1,   T3 = B10 - T1

    S0.set_to_sum(A10, A11)
    S1.set_to_diff(S0, A00)
    S2.set_to_diff(A00, A10)
    S3.set_to_diff(A01, S1)

    T0.set_to_diff(B01, B00)
    T1.set_to_diff(B11, T0)
    T2.set_to_diff(B11, B01)
    T3.set_to_diff(B10, T1)


    # The relations we need now are:

    # P0 = A00*B00
    # P1 = A01*B10
    # P2 =  S0*T0
    # P3 =  S1*T1
    # P4 =  S2*T2
    # P5 =  S3*B11
    # P6 = A11*T3

    # U0 = P0 + P1
    # U1 = P0 + P3
    # U2 = U1 + P4
    # U3 = U2 + P6
    # U4 = U2 + P2
    # U5 = U1 + P2
    # U6 = U5 + P5

    # We place the final answer into the matrix:
    # U0 U6
    # U3 U4

    cdef MatrixWindow U0, U6, U3, U4
    U0 = C.matrix_window(0,           0,           A_sub_nrows, B_sub_ncols)
    U6 = C.matrix_window(0,           B_sub_ncols, A_sub_nrows, B_sub_ncols)
    U3 = C.matrix_window(A_sub_nrows, 0,           A_sub_nrows, B_sub_ncols)
    U4 = C.matrix_window(A_sub_nrows, B_sub_ncols, A_sub_nrows, B_sub_ncols)

    if (A_sub_nrows <= cutoff) or (A_sub_ncols <= cutoff) or (B_sub_ncols <= cutoff):
        # This is the base case, so we use MatrixWindow methods directly.

        # This next chunk is arranged so that each output cell gets written
        # to exactly once. This is important because the output blocks might
        # be quite fragmented in memory, whereas our temporary buffers
        # (Q0, Q1, Q2) will be quite localised, so we can afford to do a bit
        # of arithmetic in them.

        Q0.set_to_prod(A00, B00)     # now Q0 holds P0
        Q1.set_to_prod(A01, B10)     # now Q1 holds P1
        U0.set_to_sum(Q0, Q1)        # now U0 is correct
        Q0.add_prod(S1, T1)          # now Q0 holds U1
        Q1.set_to_prod(S2, T2)       # now Q1 holds P4
        Q1.add(Q0)                   # now Q1 holds U2
        Q2.set_to_prod(A11, T3)      # now Q2 holds P6
        U3.set_to_sum(Q1, Q2)        # now U3 is correct
        Q2.set_to_prod(S0, T0)       # now Q2 holds P2
        U4.set_to_sum(Q2, Q1)        # now U4 is correct
        Q0.add(Q2)                   # now Q0 holds U5
        Q2.set_to_prod(S3, B11)      # now Q2 holds P5
        U6.set_to_sum(Q0, Q2)        # now U6 is correct

    else:
        # Recurse into sub-products.

        strassen_window_multiply_c(Q0, A00, B00, cutoff)   # now Q0 holds P0
        strassen_window_multiply_c(Q1, A01, B10, cutoff)   # now Q1 holds P1
        U0.set_to_sum(Q0, Q1)                              # now U0 is correct
        strassen_window_multiply_c(Q1, S1, T1, cutoff)     # now Q1 holds P3
        Q0.add(Q1)                                         # now Q0 holds U1
        strassen_window_multiply_c(Q1, S2, T2, cutoff)     # now Q1 holds P4
        Q1.add(Q0)                                         # now Q1 holds U2
        strassen_window_multiply_c(Q2, A11, T3, cutoff)    # now Q2 holds P6
        U3.set_to_sum(Q1, Q2)                              # now U3 is correct
        strassen_window_multiply_c(Q2, S0, T0, cutoff)     # now Q2 holds P2
        U4.set_to_sum(Q2, Q1)                              # now U4 is correct
        Q0.add(Q2)                                         # now Q0 holds U5
        strassen_window_multiply_c(Q2, S3, B11, cutoff)    # now Q2 holds P5
        U6.set_to_sum(Q0, Q2)                              # now U6 is correct


    # Now deal with the leftover row and/or column (if they exist).

    cdef MatrixWindow B_last_col, C_last_col, B_bulk, A_last_row, C_last_row, B_last_row, A_last_col, C_bulk

    if B_ncols & 1:
        # Odd number of columns in B: last column of C is A * (last column of B).
        B_last_col = B.matrix_window(0, B_ncols-1, A_ncols, 1)
        C_last_col = C.matrix_window(0, B_ncols-1, A_nrows, 1)
        C_last_col.set_to_prod(A, B_last_col)

    if A_nrows & 1:
        # Odd number of rows in A: last row of C is (last row of A) * B,
        # skipping the final column when it was already written just above.
        A_last_row = A.matrix_window(A_nrows-1, 0, 1, A_ncols)
        if B_ncols & 1:
            B_bulk = B.matrix_window(0, 0, A_ncols, B_ncols-1)
            C_last_row = C.matrix_window(A_nrows-1, 0, 1, B_ncols-1)
        else:
            B_bulk = B
            C_last_row = C.matrix_window(A_nrows-1, 0, 1, B_ncols)
        C_last_row.set_to_prod(A_last_row, B_bulk)

    if A_ncols & 1:
        # Odd inner dimension: add the (last column of A) * (last row of B)
        # contribution onto the even-sized bulk of C.
        A_last_col = A.matrix_window(0, A_ncols-1, A_sub_nrows << 1, 1)
        B_last_row = B.matrix_window(A_ncols-1, 0, 1, B_sub_ncols << 1)
        C_bulk = C.matrix_window(0, 0, A_sub_nrows << 1, B_sub_ncols << 1)
        C_bulk.add_prod(A_last_col, B_last_row)
239
cdef subtract_strassen_product(MatrixWindow result, MatrixWindow A, MatrixWindow B, Py_ssize_t cutoff):
    # Replace result by result - A*B.
    #
    # A cutoff of -1, or a result window with either dimension <= cutoff,
    # uses MatrixWindow.subtract_prod directly.  Otherwise the product is
    # formed with strassen_window_multiply_c in a freshly allocated window
    # and then subtracted.
    cdef MatrixWindow to_sub
    if (cutoff == -1 or result.ncols() <= cutoff or result.nrows() <= cutoff):
        result.subtract_prod(A, B)
    else:
        to_sub = A.new_empty_window(result.nrows(), result.ncols())
        strassen_window_multiply_c(to_sub, A, B, cutoff)
        result.subtract(to_sub)
248
249
def strassen_echelon(MatrixWindow A, cutoff):
    """
    Compute echelon form, in place.
    Internal function, call with M.echelonize(algorithm="strassen")
    Based on work of Robert Bradshaw and David Harvey at MSRI workshop in 2006.

    INPUT:
        A -- matrix window
        cutoff -- size at which algorithm reverts to naive gaussian
                  elimination and multiplication; must be at least 1.

    OUTPUT:
        None.  A is overwritten with its echelon form.  (The list of
        pivot columns computed by strassen_echelon_c is not returned
        by this wrapper.)

    NOTE:
        The cutoff used for the internal Strassen multiplications is
        not the cutoff argument; it is taken from
        A._matrix._strassen_default_cutoff(A._matrix).

    EXAMPLE:
        sage: A = matrix(QQ, 7, [5, 0, 0, 0, 0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 0, -1, 3, 1, 0, -1, 0, 0, -1, 0, 1, 2, -1, 1, 0, -1, 0, 1, 3, -1, 1, 0, 0, -2, 0, 2, 0, 1, 0, 0, -1, 0, 1, 0, 1])
        sage: B = A.copy(); B._echelon_strassen(1); B
        [ 1  0  0  0  0  0  0]
        [ 0  1  0 -1  0  1  0]
        [ 0  0  1  0  0  0  0]
        [ 0  0  0  0  1  0  0]
        [ 0  0  0  0  0  0  1]
        [ 0  0  0  0  0  0  0]
        [ 0  0  0  0  0  0  0]
        sage: C = A.copy(); C._echelon_strassen(2); C == B
        True
        sage: C = A.copy(); C._echelon_strassen(4); C == B
        True

        sage: n = 32; A = matrix(Integers(389),n,range(n^2))
        sage: B = A.copy(); B._echelon_in_place_classical()
        sage: C = A.copy(); C._echelon_strassen(2)
        sage: B == C
        True

    AUTHORS:
        -- Robert Bradshaw
    """
    if cutoff < 1:
        raise ValueError, "cutoff must be at least 1"
    # _sig_on/_sig_off come from the included interrupt.pxi and bracket the
    # long-running computation.
    _sig_on
    strassen_echelon_c(A, cutoff, A._matrix._strassen_default_cutoff(A._matrix))
    _sig_off
293
cdef strassen_echelon_c(MatrixWindow A, Py_ssize_t cutoff, Py_ssize_t mul_cutoff):
    # Put the window A into echelon form in place and return the sorted list
    # of its pivot columns.
    #
    #   cutoff     -- once nrows or ncols is <= cutoff, A.echelon_in_place()
    #                 is used directly
    #   mul_cutoff -- cutoff handed to subtract_strassen_product for the
    #                 block products formed during elimination
    #
    # The following notation will be used in the comments below, which should be understood to give
    # the general idea of what's going on, as if there were no inconvenient non-pivot columns.
    # The original matrix is given by [ A B ]
    #                                 [ C D ]
    # For compactness, let A' denote the inverse of A
    # top_left, top_right, bottom_left, and bottom_right loosely correspond to A, B, C, and D respectively,
    # however, the "cut" between the top and bottom rows need not be the same.

    cdef Py_ssize_t nrows, ncols
    nrows = A.nrows()
    ncols = A.ncols()

    if (nrows <= cutoff or ncols <= cutoff):
        return A.echelon_in_place()

    cdef Py_ssize_t top_h, bottom_cut, bottom_h, bottom_start, top_cut
    cdef Py_ssize_t prev_pivot_count
    cdef Py_ssize_t split
    split = nrows / 2

    cdef MatrixWindow top, bottom, top_left, top_right, bottom_left, bottom_right, clear

    top = A.matrix_window(0, 0, split, ncols)
    bottom = A.matrix_window(split, 0, nrows-split, ncols)

    top_pivots = strassen_echelon_c(top, cutoff, mul_cutoff)
    # effectively "multiplied" top row by A^{-1}
    #     [  I  A'B ]
    #     [  C   D  ]

    top_pivot_intervals = int_range(top_pivots)
    top_h = len(top_pivots)

    if top_h == 0:
        #                                 [ 0 0 ]
        # the whole top is a zero matrix, [ C D ]. Run echelon on the bottom
        bottom_pivots = strassen_echelon_c(bottom, cutoff, mul_cutoff)
        #             [  0   0  ]
        # we now have [  I  C'D ], proceed to sorting

    else:
        bottom_cut = max(top_pivots) + 1
        bottom_left = bottom.matrix_window(0, 0, nrows-split, bottom_cut)

        if top_h == ncols:
            bottom.set_to_zero()
            # Every column is already a pivot column of the top, so the
            # bottom contributes none.  Without this assignment
            # bottom_pivots would be unset when it is merged into the
            # pivot list below.
            bottom_pivots = []
            # [ I ]
            # [ 0 ]
            # proceed to sorting

        else:
            if bottom_cut == top_h:
                clear = bottom_left
            else:
                clear = bottom_left.to_matrix().matrix_from_columns(top_pivots).matrix_window() # TODO: read only, can I do this faster? Also below
            # Subtract off C times top from the bottom_right
            if bottom_cut < ncols:
                bottom_right = bottom.matrix_window(0, bottom_cut, nrows-split, ncols-bottom_cut)
                subtract_strassen_product(bottom_right, clear, top.matrix_window(0, bottom_cut, top_h, ncols-bottom_cut), mul_cutoff);
            # [  I      A'B   ]
            # [  *   D - CA'B ]

            # Now subtract off C times the top from the bottom_left (pivots -> 0)
            if bottom_cut == top_h:
                bottom_left.set_to_zero()
                bottom_start = bottom_cut

            else:
                for cols in top_pivot_intervals:
                    bottom_left.matrix_window(0, cols[0], nrows-split, cols[1]).set_to_zero()
                non_pivots = int_range(0, bottom_cut) - top_pivot_intervals
                for cols in non_pivots:
                    if cols[0] == 0: continue
                    # Count the top pivots to the left of this non-pivot run.
                    # NOTE: the length passed to int_range here is
                    # bottom_cut - cols[0] + cols[1], which over-shoots
                    # bottom_cut; that is harmless because every top pivot is
                    # < bottom_cut.
                    prev_pivot_count = len(top_pivot_intervals - int_range(cols[0]+cols[1], bottom_cut - cols[0]+cols[1]))
                    subtract_strassen_product(bottom_left.matrix_window(0, cols[0], nrows-split, cols[1]),
                                              clear.matrix_window(0, 0, nrows-split, prev_pivot_count),
                                              top.matrix_window(0, cols[0], prev_pivot_count, cols[1]),
                                              mul_cutoff)
                bottom_start = non_pivots._intervals[0][0]
            # [  I      A'B   ]
            # [  0   D - CA'B ]

            # Now recursively do echelon form on the bottom
            bottom_pivots_rel = strassen_echelon_c(bottom.matrix_window(0, bottom_start, nrows-split, ncols-bottom_start), cutoff, mul_cutoff)
            # [  I  A'B ]
            # [  0  I F ]
            # Pivots come back relative to bottom_start; shift them to
            # column indices of A.
            bottom_pivots = []
            for pivot in bottom_pivots_rel:
                bottom_pivots.append(pivot + bottom_start)
            bottom_h = len(bottom_pivots)

            if bottom_h + top_h == ncols:
                top.matrix_window(0, bottom_cut, split, ncols-bottom_cut).set_to_zero()
                # [ I 0 ]
                # [ 0 I ]
                # proceed to sorting

            elif bottom_h == 0:
                pass
                # [  I  A'B ]
                # [  0   0  ]
                # proceed to sorting

            else:
                #     [  I  A'B ]   [  I  E  G  ]
                # let [  0  I F ] = [  0  I  F  ]
                top_cut = max(max(bottom_pivots) + 1, bottom_cut)

                # Note: left with respect to leftmost non-zero column of bottom
                top_left = top.matrix_window(0, bottom_start, top_h, top_cut - bottom_start)

                if top_cut - bottom_start == bottom_h:
                    clear = top_left
                else:
                    clear = top_left.to_matrix().matrix_from_columns(bottom_pivots_rel).matrix_window()

                # subtract off E times bottom from top right
                if top_cut < ncols:
                    top_right = top.matrix_window(0, top_cut, top_h, ncols - top_cut)
                    subtract_strassen_product(top_right, clear, bottom.matrix_window(0, top_cut, bottom_h, ncols - top_cut), mul_cutoff);
                # [  I  *  G - EF ]
                # [  0  I     F   ]

                # Now subtract off E times bottom from top left
                if top_cut - bottom_start == bottom_h:
                    top_left.set_to_zero()

                else:
                    bottom_pivot_intervals = int_range(bottom_pivots)
                    for cols in bottom_pivot_intervals:
                        top.matrix_window(0, cols[0], top_h, cols[1]).set_to_zero()
                    non_pivots = int_range(bottom_start, top_cut - bottom_start) - bottom_pivot_intervals - top_pivot_intervals
                    for cols in non_pivots:
                        if cols[0] == 0: continue
                        # Same over-long range as above; harmless because
                        # every bottom pivot is < top_cut.
                        prev_pivot_count = len(bottom_pivot_intervals - int_range(cols[0]+cols[1], top_cut - cols[0]+cols[1]))
                        subtract_strassen_product(top.matrix_window(0, cols[0], top_h, cols[1]),
                                                  clear.matrix_window(0, 0, top_h, prev_pivot_count),
                                                  bottom.matrix_window(0, cols[0], prev_pivot_count, cols[1]),
                                                  mul_cutoff)
                # [  I  0  G - EF ]
                # [  0  I     F   ]
                # proceed to sorting

    # subrows already sorted...maybe I could do this more efficiently in cases with few pivot columns (e.g. merge sort)

    # Note: pivots aliases top_pivots, which is extended and sorted in place.
    pivots = top_pivots
    pivots.extend(bottom_pivots)
    pivots.sort()

    # For each pivot column in order, find the first row at or below cur_row
    # with a nonzero entry in that column and swap it up into position.
    cdef Py_ssize_t i, cur_row
    for cur_row from 0 <= cur_row < len(pivots):
        pivot = pivots[cur_row]
        for i from cur_row <= i < nrows:
            if not A.element_is_zero(i, pivot):
                break
        if i > cur_row and i < nrows:
            A.swap_rows(i, cur_row)

    return pivots
454
455
456
457################################
458# lots of room for optimization....
459# eventually, should I just pass these around rather than lists of ints for pivots?
460# would need new from_cols
class int_range:
    r"""
    Useful class for dealing with pivots in the strassen echelon, could have much more general application

    A set of integers stored as a sorted list of ``(start, length)``
    pairs, one pair per maximal run of consecutive integers.

    Construction:
        int_range()              -- the empty set
        int_range(indices)       -- the integers in the iterable ``indices``
                                    (duplicates are collapsed; the argument
                                    itself is left unmodified)
        int_range(start, length) -- the single run (start, length)

    Operators: ``+`` is union, ``-`` is set difference and ``*`` is
    intersection; each returns a new int_range.  ``len`` gives the number
    of integers, and iterating yields the ``(start, length)`` pairs.

    AUTHORS:
      -- Robert Bradshaw

    """
    def __init__(self, indices=None, range=None):
        if indices is None:
            self._intervals = []
            return
        if range is not None:
            # Two-argument form: (start, length), stored as given.
            self._intervals = [(int(indices), int(range))]
            return

        self._intervals = []
        # Work on a sorted copy so the caller's sequence is not reordered.
        ordered = sorted(indices)
        if len(ordered) == 0:
            return
        start = ordered[0]
        last = start
        for ix in ordered[1:]:
            if ix - last > 1:
                # A gap closes the current run.
                self._intervals.append((start, last - start + 1))
                start = ix
            last = ix
        self._intervals.append((start, last - start + 1))

    def __repr__(self):
        return str(self._intervals)

    def intervals(self):
        """Return the underlying list of (start, length) pairs."""
        return self._intervals

    def to_list(self):
        """Return every integer in the set, as a list in increasing run order."""
        members = []
        for iv in self._intervals:
            members.extend(range(iv[0], iv[0] + iv[1]))
        return members

    def __iter__(self):
        return self._intervals.__iter__()

    def __len__(self):
        total = 0
        for iv in self._intervals:
            total = total + iv[1]
        return total

    # Basically, this class is for abstracting away what I was trying to do by hand in several places
    def __add__(self, right):
        # Union: the constructor collapses any duplicates.
        return int_range(self.to_list() + right.to_list())

    def __sub__(self, right):
        # Difference; a set gives constant-time membership tests.
        excluded = set(right.to_list())
        kept = []
        for i in self.to_list():
            if i not in excluded:
                kept.append(i)
        return int_range(kept)

    def __mul__(self, right):
        # Intersection.
        mine = set(self.to_list())
        common = []
        for i in right.to_list():
            if i in mine:
                common.append(i)
        return int_range(common)
534
535
536
537
538"""
539Useful test code:
540
541def test(n, m, R, c=2):
542 A = matrix(R,n,m,range(n*m))
543 B = A.copy(); B._echelon_in_place_classical()
544 C = A.copy(); C._echelon_strassen(c)
545 return B == C
546
547E.g.,
548for n in range(5):
549 print n, test(2*n,n,Frac(QQ['x']),2)
550
551"""
552
553# This stuff gets tested extensively elsewhere, and the functions
554# below aren't callable now without using Pyrex.
555
556
557## todo: doc cutoff parameter as soon as I work out what it really means
558
559## EXAMPLES:
560## The following matrix dimensions are chosen especially to exercise the
561## eight possible parity combinations that could occur while subdividing
562## the matrix in the strassen recursion. The base case in both cases will
563## be a (4x5) matrix times a (5x6) matrix.
564
565## TODO -- the doctests below are currently not
566## tested/enabled/working -- enable them when linear algebra
567## restructuring gets going.
568
569## sage: dim1 = 64; dim2 = 83; dim3 = 101
570## sage: R = MatrixSpace(QQ, dim1, dim2)
571## sage: S = MatrixSpace(QQ, dim2, dim3)
572## sage: T = MatrixSpace(QQ, dim1, dim3)
573
574
575## sage: A = R.random_element(range(-30, 30))
576## sage: B = S.random_element(range(-30, 30))
577## sage: C = T(0)
578## sage: D = T(0)
579
580## sage: A_window = A.matrix_window(0, 0, dim1, dim2)
581## sage: B_window = B.matrix_window(0, 0, dim2, dim3)
582## sage: C_window = C.matrix_window(0, 0, dim1, dim3)
583## sage: D_window = D.matrix_window(0, 0, dim1, dim3)
584
585## sage: from sage.matrix.strassen import strassen_window_multiply
586## sage: strassen_window_multiply(C_window, A_window, B_window, 2) # use strassen method
587## sage: D_window.set_to_prod(A_window, B_window) # use naive method
588## sage: C_window == D_window
589## True
590
591## sage: dim1 = 79; dim2 = 83; dim3 = 101
592## sage: R = MatrixSpace(QQ, dim1, dim2)
593## sage: S = MatrixSpace(QQ, dim2, dim3)
594## sage: T = MatrixSpace(QQ, dim1, dim3)
595
596## sage: A = R.random_element(range(30))
597## sage: B = S.random_element(range(30))
598## sage: C = T(0)
599## sage: D = T(0)
600
601## sage: A_window = A.matrix_window(0, 0, dim1, dim2)
602## sage: B_window = B.matrix_window(0, 0, dim2, dim3)
603## sage: C_window = C.matrix_window(0, 0, dim1, dim3)
604
605## sage: strassen_window_multiply(C, A, B, 2) # use strassen method
606## sage: D.set_to_prod(A, B) # use naive method
607
608## sage: C == D
609## True
610