When running this code on N > 1024, I get a bus error/core dumped error. I am using a remote HPC and gcc/8.1. This is a matrix multiplication NxN. I don't understand where the error comes from. Specifically why there's nothing wrong with the smaller Ns. I had other codes running up to 2^20 before.
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>
#define N 2048
float *A[N], *B[N];
int i, j, k, count = 0;
float** matrix_create(int n){
float** M = malloc(n * sizeof(float*));
for (i = 0; i < n; i++)
M[i] = (float*)malloc(n * sizeof(float));
return M;
}
float** add(float* M1[], float* M2[], int n){
float** M3 = matrix_create(n);
for (i = 0; i < n; i++)
for (j = 0; j < n; j++)
M3[i][j] = M1[i][j] + M2[i][j];
return M3;
}
float** sub(float* M1[], float* M2[], int n){
float** M3 = matrix_create(n);
for (i = 0; i < n; i++)
for (j = 0; j < n; j++)
M3[i][j] = M1[i][j] - M2[i][j];
return M3;
}
void print(float* M[], int n){
for (i = 0; i < n; i++){
for (j = 0; j < n; j++)
printf("%f\t", M[i][j] );
printf("\n");
}
}
float** strassen_multiply(float* A[], float* B[], int n){
if(n == 1 ){
float** C = matrix_create(n);
C[0][0] = A[0][0] * B[0][0];
return C;
}
count++;
float** C = matrix_create(n);
int k = n/2;
/** Creating sub matrecies**/
float** A11 = matrix_create(k);
float** A12 = matrix_create(k);
float** A21 = matrix_create(k);
float** A22 = matrix_create(k);
float** B11 = matrix_create(k);
float** B12 = matrix_create(k);
float** B21 = matrix_create(k);
float** B22 = matrix_create(k);
/**Dividing the Data Matrecies A & B**/
for(i = 0; i < k; i++)
for(j = 0; j < k; j++){
A11[i][j] = A[i][j];
A12[i][j] = A[i][k+j];
A21[i][j] = A[k+i][j];
A22[i][j] = A[k+i][k+j];
B11[i][j] = B[i][j];
B12[i][j] = B[i][k+j];
B21[i][j] = B[k+i][j];
B22[i][j] = B[k+i][k+j];
}
float** P1 = strassen_multiply(A11, sub(B12, B22, k), k);
float** P2 = strassen_multiply(add(A11, A12, k), B22, k);
float** P3 = strassen_multiply(add(A21, A22, k), B11, k);
float** P4 = strassen_multiply(A22, sub(B21, B11, k), k);
float** P5 = strassen_multiply(add(A11, A22, k), add(B11, B22, k), k);
float** P6 = strassen_multiply(sub(A12, A22, k), add(B21, B22, k), k);
float** P7 = strassen_multiply(sub(A11, A21, k), add(B11, B12, k), k);
float** C11 = sub(add(add(P5, P4, k), P6, k), P2, k);
float** C12 = add(P1, P2, k);
float** C21 = add(P3, P4, k);
float** C22 = sub(sub(add(P5, P1, k), P3, k), P7, k);
for(i = 0; i < k; i++)
for(j = 0; j < k; j++){
C[i][j] = C11[i][j];
C[i][j+k] = C12[i][j];
C[k+i][j] = C21[i][j];
C[k+i][k+j] = C22[i][j];
}
for(i = 0; i < k; i++){
free( A11[i]);
free( A12[i]);
free( A21[i]);
free( A22[i]);
free( B11[i]);
free( B12[i]);
free( B21[i]);
free( B22[i]);
free( P1[i]);
free( P2[i]);
free( P3[i]);
free( P4[i]);
free( P5[i]);
free( P6[i]);
free( P7[i]);
free( C11[i]);
free( C12[i]);
free( C21[i]);
free( C22[i]);
}
free( A11);
free( A12);
free( A21);
free( A22);
free( B11);
free( B12);
free( B21);
free( B22);
free( P1);
free( P2);
free( P3);
free( P4);
free( P5);
free( P6);
free( P7);
free( C11);
free( C12);
free( C21);
free( C22);
return C;
}
int main(){
int i,j, k;
struct timeval begin, end;
for (i = 0; i < N; i++)
A[i] = (float*)malloc(N * sizeof(float));
for (i = 0; i < N; i++)
B[i] = (float*)malloc(N * sizeof(float));
for (i = 0; i < N; i++)
for (j = 0; j < N; j++){
A[i][j] = -1+2*((float)rand())/RAND_MAX;
B[i][j] = -1+2*((float)rand())/RAND_MAX;
}
float** C = matrix_create(N);
gettimeofday(&begin, 0);
C = strassen_multiply(A, B, N);
gettimeofday(&end, 0);
long seconds = end.tv_sec - begin.tv_sec;
long microseconds = end.tv_usec - begin.tv_usec;
double elapsed = seconds + microseconds*1e-6;
printf("number of recursion: %d\n\n", count);
printf("Total wall time: %f\n", elapsed);
}
Transferring some comments from the chat.
Diagnosis
You're not checking that your memory is allocated successfully. You don't know whether everything worked.
You start off with two 2048x2048 float matrices. Your strassen_multiply() function then (1) creates 8 matrices each with half the size (in terms of number of rows and columns), loads them, and then recurses 7 times in a row. Each of those recursions also creates a load of matrices — I've not sat down and calculated the total space required, but it is going to be considerable. You really need to check that your memory allocation is working. It may be that your 64-bit machine has enough space that it isn't a problem (the two initial matrices require 32 MiB of data, which may be OK).
You have calls like
float** P1 = strassen_multiply(A11, sub(B12, B22, k), k);
float** P2 = strassen_multiply(add(A11, A12, k), B22, k);
You have no way to free the matrix returned by the nested calls to sub() and add(). You can't afford not to free that memory. So, you're leaking large quantities of memory. You need a function to free your matrices — and arguably a matrix structure type that records the size of the matrix since you're going to need the size in the function to free a matrix.
You check that memory was allocated by checking for a null pointer returned by malloc(). On most systems, that's reliable. On Linux, it has the OOM (Out of Memory) manager and tends to return a non-null pointer and later crashes when you try to use the memory that it told you was available but actually wasn't. I regard that as highly undesirable behaviour, but … If you fail to allocate one of the rows, don't forget to release any previously allocated rows in that matrix.
You can't use global matrices; you have to return matrices from functions, and you have recursive functions, so global matrices won't work. You need to convert your matrices (which are all square matrices) into a structure such as:
struct Matrix
{
int size;
float **data;
};
Your existing two global arrays of pointers to float should be replaced — otherwise, you'll need special code to release the memory allocated to them.
In main() you have:
float** C = matrix_create(N);
…
C = strassen_multiply(A, B, N);
so you're leaking a full-size matrix.
The functions returning a matrix will return a matrix structure, and the ones that take two matrix arguments will be taking two pointers to (constant) matrix structures as arguments. The outlined matrix structure is so small there isn't a benefit to returning a pointer to a matrix structure.
In your current code for main(), you should have:
float **A = matrix_create(N);
float **B = matrix_create(N);
Your matrix C in the main() should be created with:
float **C = strassen_multiply(A, B, N);
The matrix C never was global.
Use matrix_create() as you do now. Just remember to free the returned value in the function that calls add() or sub(), which also means you'll need to save those intermediate results in local variables so that you can free them.
You're using global variables i, j, k for your array indices. All hell is going to break loose. Array indices must be local variables, especially if you use recursion.
That means you have to declare loop variables in each function. You should write
for (int i = 0; i < n; i++)
or equivalent for each loop. This will be more efficient than using global variables; it also gives your code a better chance of being correct. As it stands, you've not got the slightest chance of the code being correct.
Prescription
Putting those points together yields code like this:
#include <assert.h>
#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#ifndef N
#define N 128
#endif
typedef struct Matrix
{
int size;
float **data;
} Matrix;
static int count = 0;
static size_t cnt_create = 0;
static size_t cnt_destroy = 0;
static size_t cnt_add = 0;
static size_t cnt_sub = 0;
static void err_nomemory(const char *file, const char *func, int line, size_t size)
{
fprintf(stderr, "%s:%s():%d: out of memory attempting to allocate %zu bytes "
"(%d: %s)\n", file, func, line, size, errno, strerror(errno));
exit(EXIT_FAILURE);
}
static void matrix_destroy(Matrix *M)
{
cnt_destroy++;
for (int i = 0; i < M->size; i++)
free(M->data[i]);
free(M->data);
}
static Matrix matrix_create(int n)
{
cnt_create++;
Matrix M = { .size = n, .data = malloc(n * sizeof(float *)) };
if (M.data == NULL)
err_nomemory(__FILE__, __func__, __LINE__, n * sizeof(float *));
for (int i = 0; i < n; i++)
{
if ((M.data[i] = (float *)malloc(n * sizeof(float))) == NULL)
err_nomemory(__FILE__, __func__, __LINE__, n * sizeof(float));
}
return M;
}
static Matrix add(const Matrix *M1, const Matrix *M2)
{
cnt_add++;
assert(M1->size == M2->size);
Matrix M3 = matrix_create(M1->size);
for (int i = 0; i < M1->size; i++)
{
for (int j = 0; j < M1->size; j++)
M3.data[i][j] = M1->data[i][j] + M2->data[i][j];
}
return M3;
}
static Matrix sub(const Matrix *M1, const Matrix *M2)
{
cnt_sub++;
assert(M1->size == M2->size);
Matrix M3 = matrix_create(M1->size);
for (int i = 0; i < M1->size; i++)
{
for (int j = 0; j < M1->size; j++)
M3.data[i][j] = M1->data[i][j] - M2->data[i][j];
}
return M3;
}
static void matrix_print(const char *tag, const Matrix *M)
{
printf("%s (%dx%d):\n", tag, M->size, M->size);
if (M->size > 128)
{
printf("Printing suppressed - matrix too large\n");
return;
}
char buffer[32];
int len = snprintf(buffer, sizeof(buffer), "%d", M->size);
for (int i = 0; i < M->size; i++)
{
printf("[%*d]: ", len, i);
const char *pad = "";
for (int j = 0; j < M->size; j++)
{
printf("%s%f", pad, M->data[i][j]);
pad = "\t";
}
printf("\n");
}
}
static Matrix strassen_multiply(const Matrix *A, const Matrix *B)
{
assert(A->size == B->size);
if (A->size == 1)
{
Matrix C = matrix_create(A->size);
C.data[0][0] = A->data[0][0] * B->data[0][0];
return C;
}
count++;
Matrix C = matrix_create(A->size);
int k = A->size / 2;
/** Creating sub matrices**/
Matrix A11 = matrix_create(k);
Matrix A12 = matrix_create(k);
Matrix A21 = matrix_create(k);
Matrix A22 = matrix_create(k);
Matrix B11 = matrix_create(k);
Matrix B12 = matrix_create(k);
Matrix B21 = matrix_create(k);
Matrix B22 = matrix_create(k);
/** Dividing the Data Matrices A & B **/
for (int i = 0; i < k; i++)
{
for (int j = 0; j < k; j++)
{
A11.data[i][j] = A->data[i + 0][j + 0];
A12.data[i][j] = A->data[i + 0][k + j];
A21.data[i][j] = A->data[k + i][j + 0];
A22.data[i][j] = A->data[k + i][k + j];
B11.data[i][j] = B->data[i + 0][j + 0];
B12.data[i][j] = B->data[i + 0][k + j];
B21.data[i][j] = B->data[k + i][j + 0];
B22.data[i][j] = B->data[k + i][k + j];
}
}
Matrix T1 = sub(&B12, &B22);
Matrix P1 = strassen_multiply(&A11, &T1);
matrix_destroy(&T1);
Matrix T2 = add(&A11, &A12);
Matrix P2 = strassen_multiply(&T2, &B22);
matrix_destroy(&T2);
Matrix T3 = add(&A21, &A22);
Matrix P3 = strassen_multiply(&T3, &B11);
matrix_destroy(&T3);
Matrix T4 = sub(&B21, &B11);
Matrix P4 = strassen_multiply(&A22, &T4);
matrix_destroy(&T4);
Matrix T5A = add(&A11, &A22);
Matrix T5B = add(&B11, &B22);
Matrix P5 = strassen_multiply(&T5A, &T5B);
matrix_destroy(&T5A);
matrix_destroy(&T5B);
Matrix T6A = sub(&A12, &A22);
Matrix T6B = add(&B21, &B22);
Matrix P6 = strassen_multiply(&T6A, &T6B);
matrix_destroy(&T6A);
matrix_destroy(&T6B);
Matrix T7A = sub(&A11, &A21);
Matrix T7B = add(&B11, &B12);
Matrix P7 = strassen_multiply(&T7A, &T7B);
matrix_destroy(&T7A);
matrix_destroy(&T7B);
matrix_destroy(&A11);
matrix_destroy(&A12);
matrix_destroy(&A21);
matrix_destroy(&A22);
matrix_destroy(&B11);
matrix_destroy(&B12);
matrix_destroy(&B21);
matrix_destroy(&B22);
Matrix C1A = add(&P5, &P4);
Matrix C1B = add(&C1A, &P6);
Matrix C11 = sub(&C1B, &P2);
Matrix C12 = add(&P1, &P2);
Matrix C21 = add(&P3, &P4);
Matrix C2A = add(&P5, &P1);
Matrix C2B = sub(&C2A, &P3);
Matrix C22 = sub(&C2B, &P7);
matrix_destroy(&C1A);
matrix_destroy(&C1B);
matrix_destroy(&C2A);
matrix_destroy(&C2B);
matrix_destroy(&P1);
matrix_destroy(&P2);
matrix_destroy(&P3);
matrix_destroy(&P4);
matrix_destroy(&P5);
matrix_destroy(&P6);
matrix_destroy(&P7);
for (int i = 0; i < k; i++)
{
for (int j = 0; j < k; j++)
{
C.data[i + 0][j + 0] = C11.data[i][j];
C.data[i + 0][j + k] = C12.data[i][j];
C.data[k + i][j + 0] = C21.data[i][j];
C.data[k + i][k + j] = C22.data[i][j];
}
}
matrix_destroy(&C11);
matrix_destroy(&C12);
matrix_destroy(&C21);
matrix_destroy(&C22);
return C;
}
int main(void)
{
struct timeval begin, end;
Matrix A = matrix_create(N);
Matrix B = matrix_create(N);
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
{
A.data[i][j] = -1.0 + 2.0 * ((double)rand()) / RAND_MAX;
B.data[i][j] = -1.0 + 2.0 * ((double)rand()) / RAND_MAX;
}
}
gettimeofday(&begin, 0);
Matrix C = strassen_multiply(&A, &B);
gettimeofday(&end, 0);
matrix_print("A", &A);
matrix_print("B", &B);
matrix_print("C", &C);
matrix_destroy(&A);
matrix_destroy(&B);
matrix_destroy(&C);
long seconds = end.tv_sec - begin.tv_sec;
long microseconds = end.tv_usec - begin.tv_usec;
double elapsed = seconds + microseconds * 1e-6;
printf("Number of non-minimal recursive calls: %d\n", count);
printf("Number of matrices created: %zu\n", cnt_create);
printf("Number of matrices destroyed: %zu\n", cnt_destroy);
printf("Number of matrix additions: %zu\n", cnt_add);
printf("Number of matrix subtractions: %zu\n", cnt_sub);
printf("Total wall time: %f\n", elapsed);
return 0;
}
This cheats on detecting the memory allocation errors by calling a function that simply exits, rather than freeing any successfully allocated memory and returning to the caller.
The code can be compiled with -DN=256 or any other power of two. It isn't clear what would happen if the size is not a power of 2.
Performance
Some sample times and other statistics for various sizes:
N=8
Number of non-minimal recursive calls: 57
Number of matrices created: 1884
Number of matrices destroyed: 1884
Number of matrix additions: 627
Number of matrix subtractions: 399
Total wall time: 0.000480
N=16
Number of non-minimal recursive calls: 400
Number of matrices created: 13203
Number of matrices destroyed: 13203
Number of matrix additions: 4400
Number of matrix subtractions: 2800
Total wall time: 0.003723
N=32
Number of non-minimal recursive calls: 2801
Number of matrices created: 92436
Number of matrices destroyed: 92436
Number of matrix additions: 30811
Number of matrix subtractions: 19607
Total wall time: 0.025097
N=64
Number of non-minimal recursive calls: 19608
Number of matrices created: 647067
Number of matrices destroyed: 647067
Number of matrix additions: 215688
Number of matrix subtractions: 137256
Total wall time: 0.161971
N=128
Number of non-minimal recursive calls: 137257
Number of matrices created: 4529484
Number of matrices destroyed: 4529484
Number of matrix additions: 1509827
Number of matrix subtractions: 960799
Total wall time: 1.164555
N=256
Number of non-minimal recursive calls: 960800
Number of matrices created: 31706403
Number of matrices destroyed: 31706403
Number of matrix additions: 10568800
Number of matrix subtractions: 6725600
Total wall time: 7.632881
N=512
Number of non-minimal recursive calls: 6725601
Number of matrices created: 221944836
Number of matrices destroyed: 221944836
Number of matrix additions: 73981611
Number of matrix subtractions: 47079207
Total wall time: 53.730002
N=1024
Number of non-minimal recursive calls: 47079208
Number of matrices created: 1553613867
Number of matrices destroyed: 1553613867
Number of matrix additions: 517871288
Number of matrix subtractions: 329554456
Total wall time: 373.596480
N=2048
Number of non-minimal recursive calls: 329554457
Number of matrices created: 10875297084
Number of matrices destroyed: 10875297084
Number of matrix additions: 3625099027
Number of matrix subtractions: 2306881199
Total wall time: 2737.750096
Note that the number of matrices created is the same as the number destroyed; that's reassuring. Note too that there are massive numbers of matrices created and destroyed.
However, doubling the size of the matrices being multiplied is not giving a cubic time; it is better than O(N³), whereas the naïve algorithm is O(N³).
Improving Performance
One way to improve the speed of the code is to special-case 2x2 matrix multiplication. When implemented, that gave results like:
N=16
Number of large multiplications: 57
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 343
Number of matrices created: 1884
Number of matrices destroyed: 1884
Number of matrix additions: 627
Number of matrix subtractions: 399
Total wall time: 0.001045
N=32
Number of large multiplications: 400
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 2401
Number of matrices created: 13203
Number of matrices destroyed: 13203
Number of matrix additions: 4400
Number of matrix subtractions: 2800
Total wall time: 0.006532
N=64
Number of large multiplications: 2801
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 16807
Number of matrices created: 92436
Number of matrices destroyed: 92436
Number of matrix additions: 30811
Number of matrix subtractions: 19607
Total wall time: 0.038640
N=128
Number of large multiplications: 19608
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 117649
Number of matrices created: 647067
Number of matrices destroyed: 647067
Number of matrix additions: 215688
Number of matrix subtractions: 137256
Total wall time: 0.263008
N=256
Number of large multiplications: 137257
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 823543
Number of matrices created: 4529484
Number of matrices destroyed: 4529484
Number of matrix additions: 1509827
Number of matrix subtractions: 960799
Total wall time: 1.796228
N=512
Number of large multiplications: 960800
Number of 1x1 multiplications: 0
Number of 2x2 multiplications: 5764801
Number of matrices created: 31706403
Number of matrices destroyed: 31706403
Number of matrix additions: 10568800
Number of matrix subtractions: 6725600
Total wall time: 12.383302
For comparison, the number of matrices created and destroyed with the 1x1 and 2x2 special cases is:
N 1x1 2x2
16 13,203 1,884
32 92,436 13,203
64 647,067 92,436
128 4,528,484 647,067
256 31,706,403 4,529,484
512 221,944,836 31,706,403
Observe that the number of matrices created for with the 1x1 minimum case for multiplying NxN matrices is the same as for the 2x2 minimum case with 2Nx2N matrices. It also provides a fairly dramatic speed-up (c.f. 53.73 seconds for N=512 with 1x1 versus 12.38 seconds for N=512 with 2x2). A lot of the original cost is in creating 1x1 matrices to multiply together.
Other recommendations
Unslander Monica suggested:
Sub-matrices should be copied only when they are used a lot — to improve cache locality. Otherwise, a "sub matrix" is not a variable, but a concept. That means that when you do any matrix operation, you should be passing some object that describes the range of indices, and the index stride, e.g. to matrix multiplication. That way you won't be creating those sub-matrices. In general, there's lots to be done to make this code reasonable.
This would make the matrix structures more complex, but would also radically improve performance. You'd probably end up with matrix_create() returning a Matrix *, and the structure would contain extra elements: int tl_r; int tl_c; int br_r; int br_c; (top-left row and column, bottom-left row and column). You'd have another function to split a matrix into 4 quarter matrices, which would all reference the data of the unsplit matrix but with different values for the top-left and bottom-right coordinates of the sub-matrix. If you continue with the current array of pointers to arrays of floats organization, you don't need to record the 'stride' (the width of each row in the original array, which is also the height since this deals only with square matrices). You'd have to be careful with the memory management. Result arrays would be created afresh. You won't be releasing the data from quarter matrices — only from those created afresh.
Asteroids With Wings commented:
Why do you use arrays of pointers for square arrays? That's a lot of overhead for no reason. Just create an array of N*N floats! Then you can start simplifying all this crazy memory management.
And there is some justice in that, though care would be required. I still think you'd be using a structure, but the data element would be float * instead of float **, and you'd compute the array indexes (row * width + col) instead of using two subscripts. If you forego structures, you might be able to use 'variable length array' (VLA) notation instead. Care would be required. The arrays would still be dynamically allocated.
Further experiments and Suggestions
I've experimented with both 4x4 and 8x8 special cases, and both provide considerable benefit because of the reduced memory management overhead (many fewer matrix allocations and destructions). Multiplying 1024x1024 matrices with different minimum sizes gives:
Size Time #Matrices
1x1 6m 32s 1,553,613,867
2x2 1m 31s 221,944,836
4x4 23s 31,706,403
8x8 7s 4,529,484
I also coded a version that does a straight-forward raw matrix multiplication (O(N³) algorithm — using the same code as I used for 8x8 multiplication for NxN), and it is quite a bit faster than the Strassen algorithm, mainly because there's almost no memory management required.
Size Time
128x128 3ms
256x256 25ms
512x512 280ms
1024x1024 1,802ms
2048x2048 84,686ms
4096x4096 850,860ms
Note that the time multiplication between 1.80s at 1024 and 84.7s at 2048 is bigger than a cubic relation (a factor of 8 that more or less applies otherwise) — I've not investigated the cause.
I think the key to speeding up from here is not copying matrices — using what Unslander Monica suggested. I note that you probably don't need 4 coordinates as I suggested earlier; 2 coordinates and the size are sufficient (because the matrices are all square). That reduces the copying as well as the memory allocation — that will have a major benefit on performance.
I don't think Strassen is proved a failure yet. I think it shows that the naïve memory management you (we) are using is not the way to go. But I also think that unless your matrices are bigger than 1024x1024, it is likely that the naïve multiplication algorithm is probably sufficiently fast. If you must deal with big arrays, then the Strassen algorithm may be beneficial. More coding and testing needed!
TL;DR to TL;DR: This code is far from state of the art. Writing one's own matrix multiplication is replicating the work that has already been done, and that in aggregate took a long time to do - many man years of effort. There's no reason you should be writing your own. None. It only makes sense if you know the state of the art (read all the papers about this - it's a subject mostly beaten to death by now), and if you think you can do better. If all you want to do is to multiply matrices for some application, then consider it done and look for code that fits your use case. If there's none, you'll still do better to take existing code and modify it.
TL;DR: The code below does a 2048x2048 multiply in 2.2s on godbolt, using gcc 8.1 with the following options: -lm -Wall -O3 -march=knl. Using gcc 10.1, with same options, cuts the time in half to 1.1s. I'm not sure if gcc 10.1 still produces code that does all the work through - modern compilers are clever enough to figure out that this benchmark uses data produced locally, and pretty much optimize the whole program to one single function, so they may do optimizations that wouldn't be possible if the data was e.g. read from a file. From cursory glance it looks like transposition doesn't do much because the compiler reorders memory accesses anyway.
Do use godbolt to look at the assembly output - even the code below produces assembly that is vectorized, and it's optimized for the exact size of the matrix since the compiler can propagate this "N = 2048" all over the place and doesn't generate code for other cases since there are none - it can internally prove this to be the case. Compilers are very clever, complex systems these days.
Note that this code is not fully optimized yet (I don't have the weeks it'd take me to figure it all out on my own, and copying existing matrix code is pointless as an answer). I was executing it on godbolt for convenience's sake, instead of locally, so it would still take some memory layout changes to make it perform even better, especially on problems significantly larger than 2kx2k.
It's also available from this gist.
Note: this code won't compile without optimization enabled (at least -O) - the inline functions are a problem then and you get linker errors, it seems. To get good performance you need -O3 anyway.
Example output:
Allocating (0) 2048x2048 0x7fe5ab851010
Allocating (1) 2048x2048 0x7fe5aa850010
Allocating (2) 2048x2048 0x7fe5a984f010
Memory used for A,B,C matrices: 50331744
Allocating (3) 1024x1024 0x7fe5a944e010
Allocating (4) 1024x1024 0x7fe5a904d010
Allocating (5) 1024x1024 0x7fe5a8c4c010
Freeing Matrix (6) 1024x1024 0x7fe5a944e010
Freeing Matrix (5) 1024x1024 0x7fe5a904d010
Freeing Matrix (4) 1024x1024 0x7fe5a8c4c010
Freeing Matrix (3) 2048x2048 0x7fe5a984f010
Freeing Matrix (2) 2048x2048 0x7fe5aa850010
Freeing Matrix (1) 2048x2048 0x7fe5ab851010
Number of entries to Strassen multiplication: 960800
Total wall time: 1.98604s
Maximum allocated matrices/memory: 6 / 62914752
Matrices left: 0
First, some boilerplate and object lifetime tracking scaffolding - that can be enabled for diagnostic purposes to detect memory leaks and use of invalid objects.
// complete compileable example begins
#include <assert.h>
#include <math.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#define TRACK_MEMORY_USE 1
#define TRACK_LIFETIME 0
#define TRACK_VIEWS 0
#define TRACK_DUMPS 0
#define TRACK_ALLOCS 1
#define KERNEL_SIZE 16 // 16-32 is the sweet spot usually
#define TRANSPOSE_B 1 // enabling it improves things very slightly
#if defined(__GNUC__) || defined(__clang__)
#define _nodiscard_ __attribute__((warn_unused_result))
#else
#define _nodiscard_
#endif
enum TrackEvent {
TE_CREATE, // param = size of object
TE_USE,
TE_DESTROY, // param = size of object
TE_ISEMPTY,
TE_COUNT,
TE_ALLOC,
TE_MAX_COUNT,
TE_MAX_ALLOC,
};
#if TRACK_MEMORY_USE || TRACK_LIFETIME
size_t obj_track(enum TrackEvent event, const void *obj, size_t param);
#else
size_t obj_track(enum TrackEvent event, const void *obj, size_t param) { return 1; }
#endif
#define mat_use_check(M) do { \
/* A view of a view must still refer directly to the shown matrix. */ \
assert(!M->shown || !M->shown->shown); \
assert(obj_track(TE_USE, mat_shown(M), 0)); \
} while (0)
Then the definition of the Matrix itself, and some forward-declared API.
struct Matrix {
struct Matrix *shown; // a matrix being viewed, if any
float *ptr;
int n; // rows and columns in this range
int row_stride; // distance between beginnings of rows (units: elements)
int16_t tmp_count; // number of active temporary uses of this object
} typedef Matrix;
int strassen_entry_count;
//! Returns the matrix being shown if M is a view, otherwise return M itself
inline const Matrix *mat_shown(const Matrix *M) {
return M->shown ? M->shown : M;
}
Matrix *mat_create(int n);
void free_all(Matrix **M1, ...);
void free_temp(Matrix *M);
void free_all_temp(Matrix *M1, ...);
static bool mat_same_layouts(const Matrix *A, const Matrix *B);
void mat_print(const Matrix *M);
Matrix *mat_transpose(Matrix *M);
Matrix mat_block_view(const Matrix *matrix, int i, int j, int nbl);
Matrix mat_linear_block_view(const Matrix *matrix, int i, int j, int nbl);
Matrix *mat_add_to(Matrix *C, Matrix *A);
Matrix *mat_sub_from(Matrix *C, Matrix *A);
Matrix *mat_sum_to(Matrix *C, Matrix *A, Matrix *B);
Matrix *mat_diff_to(Matrix *C, Matrix *A, Matrix *B);
Now for the Strassen multiplication:
//
// Multiplication
//
static void mat_mul_impl_(float *restrict C, const float *restrict A, const float *restrict B,
int C_row_stride, int A_row_stride, int B_row_stride);
Matrix *mat_strassen_mul_impl_(Matrix *C, Matrix *A, Matrix *B, const Matrix *T) {
++ strassen_entry_count;
mat_use_check(C);
mat_use_check(A);
mat_use_check(B);
if (T) mat_use_check(T);
int const N = C->n;
assert(N >= KERNEL_SIZE && N == A->n && N == B->n && (!T || N <= T->n));
if (N == KERNEL_SIZE) {
mat_mul_impl_(C->ptr, A->ptr, B->ptr, C->row_stride, A->row_stride, B->row_stride);
} else {
Matrix A11 = mat_block_view(A, 0, 0, 2);
Matrix A12 = mat_block_view(A, 0, 1, 2);
Matrix A21 = mat_block_view(A, 1, 0, 2);
Matrix A22 = mat_block_view(A, 1, 1, 2);
Matrix B11 = mat_block_view(B, 0, 0, 2);
#if TRANSPOSE_B
Matrix B12 = mat_block_view(B, 1, 0, 2); // transposed
Matrix B21 = mat_block_view(B, 0, 1, 2); // transposed
#else
Matrix B12 = mat_block_view(B, 0, 1, 2);
Matrix B21 = mat_block_view(B, 1, 0, 2);
#endif
Matrix B22 = mat_block_view(B, 1, 1, 2);
Matrix C11 = mat_block_view(C, 0, 0, 2);
Matrix C12 = mat_block_view(C, 0, 1, 2);
Matrix C21 = mat_block_view(C, 1, 0, 2);
Matrix C22 = mat_block_view(C, 1, 1, 2);
// T1 == C12, T2 == C21, T3,T4,T5 : new
// C11 = (M7) = (A12-A22) * (B21+B22) // lease T3
// C22 = (M6) = (A21-A11) * (B11+B12) // lease T3
// T4 = (M1) = (A11+A22) * (B11+B22) // lease T3
// C11 = M7 + M1
// C22 = M6 + M1
// C12 = (M5) = (A11+A12) * B22 // lease T3
// C11 = M7 + M1 - M5
// C21 = (M2) = (A21+A22) * B11 // lease T3
// C22 = M6 + M1 - M2
// T4 = (M3) = A11 * (B12-B22) // lease T3
// C12 = M5 + M3
// C22 = M6 + M1 - M2 + M3
// T4 = (M4) = A22 * (B21-B11) // lease T3
// C11 = M7 + M1 - M5 + M4
// C21 = M2 + M4
Matrix T3_, T4_, T5_, *T3 = NULL, *T4 = NULL, *T5 = NULL;
if (T) {
T3_ = mat_linear_block_view(T, 0, 0, 2);
T4_ = mat_linear_block_view(T, 0, 1, 2);
T5_ = mat_linear_block_view(T, 1, 0, 2);
T3 = &T3_;
T4 = &T4_;
T5 = &T5_;
} else {
T3 = mat_create(A11.n);
T4 = mat_create(A11.n);
T5 = mat_create(A11.n);
}
{
Matrix *M1 = &C12;
/*M7*/ mat_strassen_mul_impl_(&C11, mat_diff_to(T4, &A12, &A22), mat_sum_to(T5, &B21, &B22), T3);
/*M6*/ mat_strassen_mul_impl_(&C22, mat_diff_to(T4, &A21, &A11), mat_sum_to(T5, &B11, &B12), T3);
/*M1*/ mat_strassen_mul_impl_(M1, mat_sum_to(T4, &A11, &A22), mat_sum_to(T5, &B11, &B22), T3);
mat_add_to(&C11, M1);
mat_add_to(&C22, M1);
}
{
Matrix *M5 = mat_strassen_mul_impl_(&C12, mat_sum_to(T5, &A11, &A12), &B22, T3);
mat_sub_from(&C11, M5);
Matrix *M2 = mat_strassen_mul_impl_(&C21, mat_sum_to(T5, &A21, &A22), &B11, T3);
mat_sub_from(&C22, M2);
}
{
Matrix *M3 = mat_strassen_mul_impl_(T4, &A11, mat_diff_to(T5, &B12, &B22), T3);
mat_add_to(&C12, M3);
mat_add_to(&C22, M3);
}
{
Matrix *M4 = mat_strassen_mul_impl_(T4, &A22, mat_diff_to(T5, &B21, &B11), T3);
mat_add_to(&C11, M4);
mat_add_to(&C21, M4);
}
free_all(&T3, &T4, &T5, NULL);
}
free_all_temp(A, B, T, NULL);
return C;
}
The multiplication kernel and the API wrappers:
static void unpack_row_major(float *const restrict B, const float *const restrict A, int const Ars);
static void unpack_col_major(float *const restrict B, const float *const restrict A, int const Ars);
#if 0
static void pack_row_major(float *const restrict B, const float *const restrict A, int const Brs);
static void pack_col_major(float *const restrict B, const float *const restrict A, int const Brs);
#endif
static void mat_mul_impl_(float *restrict C, const float *restrict A, const float *restrict B,
int C_row_stride, int A_row_stride, int B_row_stride)
{
enum { N = KERNEL_SIZE };
float AA[N*N], BB[N*N];
unpack_row_major(AA, A, A_row_stride);
if (TRANSPOSE_B)
unpack_row_major(BB, B, B_row_stride);
else
unpack_col_major(BB, B, B_row_stride);
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j) {//00
float accum = 0;
for (int k = 0; k < N; ++k) {
accum += AA[i*N+k] * BB[j*N+k];
}
C[i*C_row_stride+j] = accum;
}
}
}
static void unpack_row_major(float *const restrict B, const float *const restrict A, int const Ars)
{
const int N = KERNEL_SIZE;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
B[i*N+j] = A[i*Ars+j];
}
static void unpack_col_major(float *const restrict B, const float *const restrict A, int const Ars)
{
const int N = KERNEL_SIZE;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
B[i*N+j] = A[j*Ars+i];
}
#if 0
static void pack_row_major(float *const restrict B, const float *const restrict A, int const Brs)
{
const int N = KERNEL_SIZE;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
B[i*Brs+j] = A[i*N+j];
}
static void pack_col_major(float *const restrict B, const float *const restrict A, int const Brs)
{
const int N = KERNEL_SIZE;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
B[j*Brs+i] = A[i*N+j];
}
#endif
Matrix *mat_strassen_mul_to(Matrix *C, Matrix *A, Matrix *B) {
mat_use_check(C);
mat_use_check(A);
mat_use_check(B);
assert(C->n == A->n && C->n == B->n);
assert(C->n >= KERNEL_SIZE);
if (TRANSPOSE_B)
mat_transpose(B);
if (C->n <= 64) {
printf("A\n");
mat_print(A);
printf("B\n");
mat_print(B);
}
mat_strassen_mul_impl_(C, A, B, NULL);
if (C->n <= 64) {
printf("C\n");
mat_print(C);
}
if (TRANSPOSE_B)
mat_transpose(B);
return C;
}
_nodiscard_ Matrix *mat_strassen_mul(Matrix *A, Matrix *B) {
mat_use_check(A);
mat_use_check(B);
Matrix *C = mat_create(A->n);
mat_strassen_mul_to(C, A, B);
return C;
}
Now for addition/subtraction:
//
// Addition/subtraction
//
Matrix *mat_sum_to(Matrix *C, Matrix *A, Matrix *B) {
mat_use_check(C);
mat_use_check(A);
mat_use_check(B);
assert(C->n == A->n && C->n == B->n);
float *restrict c = C->ptr, * restrict a = A->ptr, * restrict b = B->ptr;
int const N = A->n;
int const Ars = A->row_stride, Brs = B->row_stride, Crs = C->row_stride;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j)
c[i*Crs+j] = a[i*Ars+j] + b[i*Brs+j];
}
free_all_temp(A, B, NULL);
return C;
}
_nodiscard_ Matrix *mat_sum(Matrix *A, Matrix *B) {
return mat_sum_to(mat_create(A->n), A, B);
}
Matrix *mat_add_to(Matrix *C, Matrix *B) {
mat_use_check(C);
mat_use_check(B);
assert(C->n == B->n);
float *restrict c = C->ptr, *restrict b = B->ptr;
int const N = C->n;
int const Brs = B->row_stride, Crs = C->row_stride;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
c[i*Crs+j] += b[i*Brs+j];
free_temp(B);
return C;
}
Matrix *mat_diff_to(Matrix *C, Matrix *A, Matrix *B) {
mat_use_check(C);
mat_use_check(A);
mat_use_check(B);
assert(C->n == A->n && C->n == B->n);
int const N = A->n, Ars = A->row_stride, Brs = B->row_stride, Crs = C->row_stride;
float *restrict c = C->ptr, *restrict a = A->ptr, *restrict b = B->ptr;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
c[i*Crs+j] = a[i*Ars+j] - b[i*Brs+j];
free_all_temp(A, B, NULL);
return C;
}
_nodiscard_ Matrix *mat_diff(Matrix *A, Matrix *B) {
return mat_diff_to(mat_create(A->n), A, B);
}
Matrix *mat_sub_from(Matrix *C, Matrix *B) {
mat_use_check(C);
mat_use_check(B);
assert(C->n == B->n);
float *restrict c = C->ptr, *restrict b = B->ptr;
int const N = C->n, Brs = B->row_stride, Crs = C->row_stride;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
c[i*Crs+j] -= b[i*Brs+j];
free_temp(B);
return C;
}
And some ways of filling the matrices with values:
//
// Misc Value Setting
//
_nodiscard_ size_t mat_num_bytes(const Matrix *A) {
mat_use_check(A);
return A ? sizeof(*A) + sizeof(float) * A->n * A->row_stride : 0;
}
Matrix *mat_zero(Matrix *M) {
mat_use_check(M);
int const N = M->n;
float *restrict m = M->ptr;
for (int i = 0; i < N; ++i) {
memset(m, 0, sizeof(float) * N);
m += M->row_stride;
}
return M;
}
_nodiscard_ Matrix *mat_zeroed(int n) { return mat_zero(mat_create(n)); }
Matrix *mat_randomize(Matrix *M) {
mat_use_check(M);
float *restrict m = M->ptr;
const int N = M->n, Mrs = M->row_stride;
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j)
m[i*Mrs+j] = -1. + 2.*((float)rand())/RAND_MAX;
}
return M;
}
_nodiscard_ Matrix *mat_randomized(int n) { return mat_randomize(mat_create(n)); }
Matrix *mat_row_seq(Matrix *M) {
mat_use_check(M);
mat_zero(M);
float *restrict m = M->ptr;
const int N = M->n;
for (int i = 0; i < N; ++i)
m[i] = i;
return M;
}
Matrix *mat_col_seq(Matrix *M) {
mat_use_check(M);
mat_zero(M);
float *restrict m = M->ptr;
const int N = M->n, Mrs = M->row_stride;
for (int i = 0; i < N; ++i)
m[i*Mrs] = i;
return M;
}
Matrix *mat_transpose(Matrix *M) {
mat_use_check(M);
const int N = M->n, Mrs = M->row_stride;
float *const restrict m = M->ptr;
for (int i = 0; i < N; ++i) {
for (int j = i+1; j < N; ++j) {
float a = m[i*Mrs+j];
m[i*Mrs+j] = m[j*Mrs+i];
m[j*Mrs+i] = a;
}
}
return M;
}
Matrix *mat_copy_to(Matrix *M, Matrix *A) {
mat_use_check(M);
mat_use_check(A);
assert(M->n == A->n);
if (mat_same_layouts(M, A)) {
memcpy(M->ptr, A->ptr, mat_num_bytes(M));
} else {
float *restrict m = M->ptr, *restrict a = A->ptr;
int const N = M->n, Ars = A->row_stride, Mrs = M->row_stride;
for (int i = 0; i < N; ++i)
for (int j = 0; j < N; ++j)
m[i*Mrs+j] = a[i*Ars+j];
}
free_temp(A);
return M;
}
Now - the memory management:
//
// Matrix Creation/Destruction
//
//! A modifier used to pass a temporary matrix as a matrix argument - the
//! called function will treat this matrix as a temporary one and free it when
//! it returns.
Matrix *temp(Matrix *M) {
mat_use_check(M);
assert(M->tmp_count >= 0);
M->tmp_count ++;
return M;
}
inline size_t mat_alloc_size(const int n) {
return sizeof(Matrix) + sizeof(float) * n * n;
}
__attribute__((noreturn)) static void out_of_memory(void) {
fprintf(stderr, "Out of memory\n");
fflush(stderr);
abort();
}
_nodiscard_ Matrix *mat_create(int const n) {
size_t const bytes = mat_alloc_size(n);
Matrix *const M = malloc(bytes);
if (TRACK_ALLOCS) {
printf("Allocating (%ld) %dx%d %p\n", obj_track(TE_COUNT, NULL, 0), n, n, M);
fflush(stdout);
}
if (!M) out_of_memory();
bool ok = obj_track(TE_CREATE, M, bytes);
assert(ok);
M->shown = NULL;
M->ptr = (void*)(M+1);
M->n = n;
M->row_stride = n;
M->tmp_count = 0;
return M;
}
void mat_free(Matrix **M) {
Matrix *mp = *M;
if (!mp || mp->shown) return;
mat_use_check(mp);
const size_t bytes = mat_alloc_size(mat_shown(mp)->n);
if (TRACK_ALLOCS) {
printf("Freeing %s (%ld) %dx%d %p\n",
mp->shown ? "View" : "Matrix", obj_track(TE_COUNT, NULL, 0),
mp->n, mp->n, mp);
fflush(stdout);
}
free(mp);
bool ok = obj_track(TE_DESTROY, mp, bytes);
assert(ok);
*M = mp = NULL;
}
void free_all(Matrix **M, ...) {
va_list args;
va_start(args, M);
while (M) {
mat_free(M);
M = va_arg(args, Matrix **);
}
va_end(args);
}
void free_temp(Matrix *M) {
if (!M) return;
if (!M->tmp_count) return;
assert(M->tmp_count > 0);
if (!--M->tmp_count)
mat_free(&M);
}
void free_all_temp(Matrix *M, ...) {
va_list args;
va_start(args, M);
while(M) {
free_temp(M);
M = va_arg(args, Matrix *);
}
va_end(args);
}
And ways of querying/outputting the matrix:
//
// Matrix Query and Output
//
static _nodiscard_ bool mat_same_layouts(const Matrix *A, const Matrix *B) {
mat_use_check(A);
mat_use_check(B);
return A->n == B->n && A->row_stride == B->row_stride;
}
void mat_print(const Matrix *M) {
mat_use_check(M);
float *m = M->ptr;
for (int i = 0; i < M->n; ++i) {
for (int j = 0; j < M->n; ++j) printf("%.0f ", m[j]);
printf("\n");
m += M->row_stride;
}
fflush(stdout);
}
void mat_dump(const Matrix *M) {
mat_use_check(M);
if (!TRACK_DUMPS) return;
if (!M) {
printf("Null\n");
} else {
const char *kind = !M->shown ? "Matrix" : "View";
printf("%s %dx%d <->%d %p", kind, M->n, M->n, M->row_stride, M->ptr);
if (M->shown) printf(" ..%p", M->shown->ptr);
printf("\n");
}
fflush(stdout);
}
And now a quite important feature: matrix views. This allows creating "matrices" that don't own their data, but merely act as views onto another matrix. This is leveraged in Strassen multiplication to get rid of lots of memory copying and allocations:
//
// Views of a Matrix
//
static void track_view(const Matrix *V, const char *kind) {
if (TRACK_VIEWS) {
printf("New %s %dx%d <->%d %p\n", kind, V->n, V->n, V->row_stride, V->ptr);
fflush(stdout);
}
}
//! Returns a sub-block *view* of a given matrix. The block's index is i,j (0-based),
//! out of an nxn square of blocks. Thew view is by-value and is meant to be
//! kept by value. It doesn't allocate.
_nodiscard_ Matrix mat_block_view(const Matrix *M, int i, int j, int nbl) {
mat_use_check(M);
const Matrix *shown = mat_shown(M);
Matrix view = { .shown = (Matrix*)shown };
view.n = M->n / nbl;
view.row_stride = M->row_stride;
view.ptr = M->ptr + ((size_t)i * view.n * view.row_stride) + ((size_t)j * view.n);
track_view(&view, "View");
return view;
}
//! Returns a sub-block linearized view of a given matrix, i.e. the sub-blocks
//! have the smallest possible row_stride. Useful for cache locality when reusing
//! temporary allocations. The source matrix must be contiguous, i.e. it can't be
//! a mat_block_view.
_nodiscard_ Matrix mat_linear_block_view(const Matrix *M, int i, int j, int nbl) {
mat_use_check(M);
assert(M->row_stride == M->n);
const Matrix *shown = mat_shown(M);
Matrix view = { .shown = (Matrix*)shown };
view.n = M->n / nbl;
view.row_stride = view.n;
view.ptr = M->ptr + ((size_t)i * nbl + (size_t)j) * view.n * view.n;
track_view(&view, "Linear View");
return view;
}
And a little example code:
//
// Example/Test
//
typedef struct timeval timeval;
_nodiscard_ timeval get_time(void) {
timeval result;
gettimeofday(&result, 0);
return result;
}
_nodiscard_ double time_delta(const timeval *start, const timeval *end) {
double const t1 = start->tv_sec + start->tv_usec/1e6;
double const t2 = end->tv_sec + end->tv_usec/1e6;
return t2-t1;
}
int main()
{
size_t const N = 2048;
#if 1
Matrix *A = mat_randomize(mat_create(N));
Matrix *B = mat_randomize(mat_create(N));
#else
Matrix *A = mat_row_seq(mat_create(N));
Matrix *B = mat_col_seq(mat_create(N));
#endif
Matrix *C = mat_create(N);
printf("Memory used for A,B,C matrices: %lu\n", obj_track(TE_ALLOC, NULL, 0));
timeval start = get_time();
mat_strassen_mul_to(C, A, B);
timeval end = get_time();
free_all(&C, &B, &A, NULL);
printf("Number of entries to Strassen multiplication: %d\n", strassen_entry_count);
printf("Total wall time: %gs\n", time_delta(&start, &end));
printf("Maximum allocated matrices/memory: %lu / %lu\n",
obj_track(TE_MAX_COUNT, NULL, 0), obj_track(TE_MAX_ALLOC, NULL, 0));
printf("Matrices left: %lu\n", obj_track(TE_COUNT, NULL, 0));
assert(obj_track(TE_ISEMPTY, NULL, 0));
}
And finally the nitty-gritty of object lifetime diagnostics. This is entirely optional, but was helpful as I've modified and simplified the memory management code throughout my experiments:
//
// Diagnostic Object Tracking
//
#if TRACK_MEMORY_USE || TRACK_LIFETIME
struct {
const void *ptr;
} typedef ObjEntry;
struct {
ObjEntry *objects;
size_t capacity;
size_t count;
size_t alloc;
size_t max_alloc;
size_t max_count;
bool is_sorted;
} typedef ObjState;
bool obj_count(const void *obj, size_t size, ObjState *ost);
bool obj_uncount(const void *obj, size_t size, ObjState *ost);
bool obj_push_back(const void *obj, size_t size, ObjState *ost);
bool obj_find(const void *obj, ObjState *ost);
bool obj_remove(const void *obj, size_t size, ObjState *ost);
size_t obj_track(enum TrackEvent event, const void *obj, size_t param) {
static ObjState ost;
switch (event) {
#if TRACK_MEMORY_USE && !TRACK_LIFETIME
case TE_CREATE:
return obj_count(obj, param, &ost);
case TE_DESTROY:
return obj_uncount(obj, param, &ost);
case TE_USE:
return !!obj;
#else
case TE_CREATE:
return !obj_find(obj, &ost) && obj_push_back(obj, param, &ost);
case TE_USE:
return obj_find(obj, &ost);
case TE_DESTROY:
return obj_remove(obj, param, &ost);
#endif
case TE_ISEMPTY:
return !ost.count;
case TE_COUNT:
return ost.count;
case TE_ALLOC:
return ost.alloc;
case TE_MAX_COUNT:
return ost.max_count;
case TE_MAX_ALLOC:
return ost.max_alloc;
default:
return false;
}
}
bool obj_count(const void *obj, size_t size, ObjState *ost) {
if (!obj || !size) return false;
++ost->count;
ost->alloc += size;
if (ost->count > ost->max_count) ost->max_count = ost->count;
if (ost->alloc > ost->max_alloc) ost->max_alloc = ost->alloc;
return true;
}
bool obj_uncount(const void *obj, size_t size, ObjState *ost) {
if (!obj || !size) return false;
--ost->count;
ost->alloc -= size;
return true;
}
bool obj_push_back(const void *obj, size_t size, ObjState *ost) {
if (!ost->capacity) {
ost->capacity = 32;
ost->objects = malloc(sizeof(ObjEntry) * ost->capacity);
}
else if (ost->capacity == ost->count) {
ost->capacity *= 2;
ost->objects = realloc(ost->objects, sizeof(ObjEntry) * ost->capacity);
if (!ost->objects) out_of_memory();
}
if (!obj_count(obj, size, ost))
return false;
ost->objects[ost->count-1] = (ObjEntry){ .ptr = obj };
if (ost->count == 1) {
ost->is_sorted = true;
} else {
ObjEntry *second_to_last = &(ost->objects[ost->count - 2]);
ost->is_sorted = ost->is_sorted && obj > second_to_last->ptr;
}
return true;
}
static int ptr_comp(const void *a, const void *b) {
const ObjEntry *obj1 = a;
const ObjEntry *obj2 = b;
ssize_t diff = obj1->ptr - obj2->ptr;
return diff < 0 ? - 1 : diff > 0 ? 1 : 0;
}
int obj_lookup(const void *obj, ObjState *ost) {
if (!ost->is_sorted) {
qsort(ost->objects, ost->count, sizeof(ObjEntry), ptr_comp);
ost->is_sorted = true;
}
if (ost->count > 1) {
// Sanity check: the first two objects must be sorted, at least.
assert(ost->objects[0].ptr < ost->objects[1].ptr);
}
const ObjEntry *found =
bsearch(&obj, ost->objects, ost->count, sizeof(ObjEntry), ptr_comp);
return (!found) ? -1 : (found - ost->objects);
}
bool obj_find(const void *obj, ObjState *ost) { return obj_lookup(obj, ost) >= 0; }
bool obj_erase(int pos, size_t size, ObjState *ost) {
assert(pos >= -1 && pos < ost->count);
if (pos == -1) return false;
if (!obj_uncount(ost->objects[pos].ptr, size, ost))
return false;
if (pos < (ost->count))
memmove(ost->objects + pos, ost->objects + pos + 1,
sizeof(ObjEntry) * (ost->count - pos));
return true;
}
bool obj_remove(const void *obj, size_t size, ObjState *ost) {
int index = obj_lookup(obj, ost);
return obj_erase(index, size, ost);
}
#endif // TRACK_MEMORY_USE || TRACK_LIFETIME
// complete compileable example ends
That's all, folks :)