/*
 * som main (main.c)
 *
 * A. Date
 * 2015.5.18
 */
/*
 * NOTE(review): this file appears to have been extracted from HTML with every
 * "<...>" span stripped and all newlines collapsed. The system header names
 * below, several loop headers, the training-loop body in test_som, and the
 * remainder of main() are missing and must be restored from the original
 * source; they are marked individually below. Presumably the missing headers
 * include <stdio.h> (printf) and <stdlib.h> (malloc, free, drand48) — confirm.
 */
#include
#include
#include
#include "mnist.h"
#include "som.h"
#include "irl_utility.h"
#include "irl_gnuplot.h"

/*
 * test_mnist: read the MNIST images and labels and print every label as
 * "index:label ", starting a new line before every 20th entry.
 *
 * NOTE(review): `digit` is read but never used, and neither `digit` nor
 * `label` is freed here. Whether the caller owns these buffers depends on
 * mnist_read_images/mnist_read_labels — confirm in mnist.h.
 */
void test_mnist(){
  int i;
  int n = MNIST_NUM_IMAGES;
  double **digit;
  int *label;
  digit = mnist_read_images();
  label = mnist_read_labels();
  for(i=0; i < n; i++){
    if ( i % 20 == 0){
      printf("\n");
    }
    printf("%d:%d ", i, (int)label[i] );
  }
}

/*
 * test_som: build a SOM with n1 = 784 inputs and n2 = 64 units, initialize
 * its reference vectors from the MNIST images, start gnuplot, and train it
 * for n_trainings iterations while displaying the weights.
 *
 * NOTE(review): the training loop below is damaged (see the note on it).
 * `img`, `digit` and `label` are not freed in the visible code.
 */
void test_som(){
  int dim_som = 2;        // dimension of the SOM array (the neural field)
  int n1 = 784;           // dimension of the input signal (presumably MNIST_X_UNITS * MNIST_Y_UNITS — confirm)
  int n2 = 64;            // number of SOM units
  int n_trainings = 1000;
  int output_format = 0;
  int t;
  double mu = 0.1;
  double sigma = 3.0;
  SOM *som;
  int n = MNIST_NUM_IMAGES;
  double **digit;
  int *label;
  double *img;            // for gnuplot display
  char buf[100];
  int MARGIN_X = 2;
  int MARGIN_Y = 2;
  int N_XUNITS = MNIST_X_UNITS;
  int N_YUNITS = MNIST_Y_UNITS;
  // Display buffer: (MARGIN_X + N_XUNITS + MARGIN_X) * (MARGIN_Y + N_YUNITS + MARGIN_Y)
  // doubles per unit, times n2 units (presumably one padded image per unit — confirm).
  img = alloc_1d_dbl( (MARGIN_X + N_XUNITS + MARGIN_X)*(MARGIN_Y + (N_YUNITS + MARGIN_Y))*n2);
  // Read the set of input signals.
  digit = mnist_read_images();
  label = mnist_read_labels();
  // Build the network.
  som = som_new(n1, n2, dim_som);
  // Initialize the network.
  som_init_m(som, n1, n2, mu, sigma, digit);
  irl_init_gnuplot (output_format);
  // Present input signals to the SOM and train it.
  // NOTE(review): extraction damage — the text from "<" (presumably
  // "t<n_trainings; t++){ ...") up to the ">" of "som->m" was stripped, so the
  // per-iteration training call and the name of the display function whose
  // arguments end in "som->m, img, buf, n1, n2, output_format" are lost.
  // Restore from the original source.
  for(t=0; tm, img, buf, n1, n2, output_format); }
  // Once training has progressed far enough:
  // learn the connections from layer 2 to layer 3.
  // NOTE(review): no code for this step exists in the visible source.
  free_som(som);
}

/*
 * som_init_m: set the reference vectors of `this` by copying input vectors,
 * this->m[i][j] = digit[i][j]. The alternatives 0.1*nrand() and drand48()
 * are left commented out in the original.
 *
 * `mu` and `sigma` are not referenced in the visible body.
 * NOTE(review): the loop headers were stripped during extraction. Since
 * som_new allocates m as alloc_2d_dbl(n2, n1), they were presumably
 * "for(i=0; i<n2; i++){ for(j=0; j<n1; j++){ // this->" (i.e. the first n2
 * images are copied) — confirm against the original. The fragment
 * "im[i][j] = 0.1*nrand();" is the tail of a commented-out line.
 */
void som_init_m(SOM *this, int n1, int n2, double mu, double sigma, double **digit){
  int i,j;
  for(i=0; im[i][j] = 0.1*nrand();
  // this->m[i][j] = drand48();
  this->m[i][j] = digit[i][j];
  } } }

/*
 * som_new: allocate a SOM with n1 inputs, n2 units and a dim-dimensional
 * array r.
 *   x  : n1 doubles, set to 0.0
 *   m  : n2 x n1 (via alloc_2d_dbl(n2,n1)), set to 0.0
 *   dm : n2 x n1, set to 0.0
 *   r  : n2 x dim, filled with drand48() values (presumably each unit's
 *        coordinates in the neural field — confirm)
 * Returns NULL if the malloc of the SOM struct fails.
 *
 * NOTE(review): results of alloc_1d_dbl/alloc_2d_dbl are not checked.
 * NOTE(review): loop headers stripped during extraction; from the
 * allocation sizes they were presumably j<n1 for x, i<n2 / j<n1 for m and
 * dm, and i<n2 / k<dim for r — confirm.
 */
SOM* som_new(int n1, int n2, int dim) {
  SOM* this;
  int i, j, k;
  this = (SOM*) malloc (sizeof(SOM));
  if (this == NULL) return NULL;
  this->n1 = n1;
  this->n2 = n2;
  this->dim = dim;
  this->x = alloc_1d_dbl(n1);
  for(j=0; jx[j] = 0.0; }
  this->m = alloc_2d_dbl(n2,n1);
  this->dm = alloc_2d_dbl(n2,n1);
  for(i=0; im[i][j] = 0.0; this->dm[i][j] = 0.0; } }
  this->r = alloc_2d_dbl(n2,dim);
  for(i=0; ir[i][k] = drand48(); } }
  return this;
}

/*
 * free_som: free x, then each of the n2 rows of m, dm and r, then the row
 * pointer arrays of m, dm and r.
 *
 * NOTE(review): `this` itself (malloc'ed in som_new) is never freed, so
 * every som_new/free_som pair leaks sizeof(SOM) bytes.
 * NOTE(review): row-by-row freeing assumes alloc_2d_dbl allocates each row
 * separately with malloc — confirm in irl_utility.
 * NOTE(review): loop header stripped; presumably
 * "for (i=0; i<n2; i++){ free((char *) this->m[i]);" — confirm.
 */
void free_som(SOM* this) {
  int i;
  int n2 = this->n2;
  free((char *) this->x);
  for (i=0; im[i]);
  free((char *) this->dm[i]);
  free((char *) this->r[i]);
  }
  free((char *) this->m);
  free((char *) this->dm);
  free((char *) this->r);
}

/*
 * main: only the beginning is visible. Everything from the loop condition
 * onward was lost in extraction or lies outside this chunk.
 */
int main (int argc, char *argv[] ){
  int i;
  int seed = RAND_SEED;
  for (i=1; i