/*
 * irl_mnist_show_samples_nn.c
 *
 * How to compile:
 *   % gcc irl_mnist_show_samples_nn.c irl_utility.c -lm
 *
 * - Prototype code for running experiments with neural-network models.
 * - Loads the MNIST data.
 * - Displays the data with gnuplot.
 * - Uses only standard libraries, so it should still run unchanged
 *   20 years from now (hopefully)!
 * - Quite likely still contains some leftover junk.
 *
 * A. Date 2014.5.26
 *
 * NOTE(review): this copy of the file is damaged. The header names of the
 * first six #include lines are missing, and several spans of code were lost
 * (apparently stripped as if they were HTML tags; each is marked below with
 * NOTE(review)). It does not compile as-is -- restore it from the original
 * source before building.
 */
// NOTE(review): the header names of these six #include lines were lost.
// Judging from the calls below (printf/FILE, malloc/free/rand/drand48/exit/atoi,
// open/O_RDONLY, read), they presumably included <stdio.h>, <stdlib.h>,
// <fcntl.h> and <unistd.h> among others -- TODO restore.
#include
#include
#include
#include
#include
#include
// Presumably declares NN and the MNIST_NUM_IMAGES / MNIST_SIZE / IMAGE_FILE /
// RAND_SEED constants used below (not visible here -- confirm).
#include "irl_mnist_nn.h"
// Presumably provides alloc_1d_dbl, alloc_2d_uchar, alloc_2d_dbl and
// normalize_input (not visible here -- confirm).
#include "irl_utility.h"

// NOTE(review): test_nn_mnist() is not defined anywhere in the surviving text;
// the function actually defined and called from main() is test_mnist_nn().
void test_nn_mnist(int generate_fugure);
void init_gnuplot (int generate_figure);
void plot3d_gnuplot (NN* nn, double *img, char *title, int n_visible, int n_hidden, int generate_figure);
void write_multi_images(NN* nn, double *img, int nx, int ny,int n_visible, int n_hidden );
double uniform(double, double);

// Global FILE handle; presumably the gnuplot pipe used by the plotting
// routines (their bodies are not visible -- confirm).
FILE *gp;
// File-scope scratch buffer; test_mnist_nn() shadows it with a local buf.
char buf[100];

/*
 * Return a pseudo-random double in [min, max), drawn from rand().
 *
 * NOTE(review): main() seeds only srand48(), which does not seed rand(), so the
 * -r option does not affect this generator. drand48() would be consistent
 * with the rest of the file.
 */
double uniform(double min, double max) {
  return rand() / (RAND_MAX + 1.0) * (max - min) + min;
}

/*
 * Initialise an NN in place.
 *
 *   this      : network to fill in
 *   N         : stored in this->N (test_mnist_nn passes train_N)
 *   n_visible : number of inputs per hidden unit (row length of W)
 *   n_hidden  : number of hidden units (number of rows of W)
 *   W, hbias, vbias : caller-supplied arrays to adopt, or NULL to allocate
 *
 * When W is NULL, an n_hidden x n_visible matrix is allocated as one
 * contiguous block (W[0]) plus an array of row pointers, and is meant to be
 * filled with uniform(-a, a), where a = 1/n_visible. NULL biases are allocated
 * and meant to be zeroed. malloc results are not checked.
 */
void NN__construct(NN* this, int N, int n_visible, int n_hidden, double **W, double *hbias, double *vbias) {
  int i, j;
  double a = 1.0 / n_visible;
  this->N = N;
  this->n_visible = n_visible;
  this->n_hidden = n_hidden;
  if(W == NULL) {
    this->W = (double **)malloc(sizeof(double*)*n_hidden);
    this->W[0] = (double *)malloc(sizeof(double)*n_visible*n_hidden);
    // NOTE(review): the next two loops are damaged. The text between "i" and
    // "W[i]" was lost: presumably "<n_hidden; i++){ this->", and in the second
    // loop also an inner "for(j=0; j<n_visible; j++){". The presumed intent is
    // to point each row W[i] into the contiguous block, then fill every
    // W[i][j] with uniform(-a, a).
    for(i=0; iW[i] = this->W[0]+i*n_visible; }
    for(i=0; iW[i][j] = uniform(-a, a); } }
  } else{
    this->W = W;
  }
  if(hbias == NULL) {
    this->hbias = (double *)malloc(sizeof(double)*n_hidden);
    // NOTE(review): same damage as above; presumably zeroes hbias[0..n_hidden-1].
    for(i=0; ihbias[i] = 0; }
  } else{
    this->hbias = hbias;
  }
  if(vbias == NULL) {
    this->vbias = (double *)malloc(sizeof(double) * n_visible);
    // NOTE(review): same damage as above; presumably zeroes vbias[0..n_visible-1].
    for(i=0; ivbias[i] = 0; }
  } else{
    this->vbias = vbias;
  }
}

/*
 * Release the arrays held by an NN.
 *
 * NOTE(review): W[0], W, hbias and vbias are freed unconditionally, including
 * arrays that were passed in to NN__construct rather than allocated by it. A
 * caller that supplies its own arrays must therefore not free them again.
 * W is also assumed to use NN__construct's layout, where W[0] owns every row.
 */
void NN__destruct(NN* this) {
  free(this->W[0]);
  free(this->W);
  free(this->hbias);
  free(this->vbias);
}

/* Read the training images into memory */
// Opens IMAGE_FILE read-only and skips its header. If the open fails, it
// prints a message and terminates the program.
// NOTE(review): the failure path writes to stdout and calls exit(0), which
// reports success to the shell; stderr plus a non-zero status would be
// clearer. The read() return value is not checked.
void read_mnist_images(unsigned char **digit){
  int i, fd;
  static int num[10];
  if ((fd=open(IMAGE_FILE,O_RDONLY))==-1){
    printf("couldn't open image file");
    exit(0);
  }
  /* skip headers */
  // Reads 4*sizeof(int) bytes into num. That is 16 bytes when int is 4 bytes,
  // presumably the IDX magic/count/rows/cols header -- confirm.
  read(fd,num, 4*sizeof(int) );
  // NOTE(review): first large damaged span. Everything between the "i" of the
  // loop condition below and a later "->" was lost. This includes the rest of
  // read_mnist_images() (presumably the loop that fills digit) and the start
  // of the following function(s). The surviving "W; min = 0.0; ..." tail looks
  // like part of a weight-display routine (presumably "w = nn->W;").
  // TODO: restore from the original source.
  for (i=0; iW; min = 0.0;
  // Smallest weight value, clamped at 0 because min starts at 0.0.
  for (i = 0; i < n_hidden; i++){
    // NOTE(review): assuming w is the n_hidden x n_visible matrix built by
    // NN__construct, j runs 1..n_visible inclusive. w[i][0] is therefore
    // skipped, and w[i][n_visible] reads one past the row -- past the whole
    // block when i == n_hidden-1. This probably should be
    // "j = 0; j < n_visible".
    for (j=1; j <= n_visible; j++){
      if ( w[i][j] < min ){
        min = w[i][j];
      }
    }
  }
  // NOTE(review): second large damaged span. The text between "i" and a later
  // "->" was lost: the rest of the display routine above, and the signature
  // and local declarations (i, j, m) of the weight-initialisation routine
  // below. test_mnist_nn() calls that routine as
  // NN_init_weights_mnist(&nn, train_N, train_X). The surviving "n_hidden;"
  // is presumably the tail of "int n_hidden = this->n_hidden;".
  for (i=0; in_hidden; int n_inputs = this->n_visible;
  // Initialise each hidden unit's weight row by copying a training image.
  for (i=0; i < n_hidden; i++){
    // NOTE(review): m is drawn at random but never used. Row i is copied from
    // train_X[i], so the hidden units get the first n_hidden training images,
    // not random ones. If random samples were intended, the copy should read
    // train_X[m][j], and m should not be reduced modulo n_hidden.
    m = (int)(drand48()*(double)train_N);
    m = m % n_hidden;
    for (j=0; j < n_inputs; j++){
      this->W[i][j] = train_X[i][j];
    }
  }
}

/*
 * Driver: load MNIST, preprocess it, build an NN, initialise its weights from
 * the training images, and plot them via gnuplot.
 *
 *   generate_pdf : passed through to init_gnuplot() and plot3d_gnuplot()
 *                  (main()'s usage text lists 1:pdf 2:gif 3:eps)
 *
 * Preprocessing subtracts the mean pixel value over the whole training set
 * from every pixel, then calls normalize_input() on each image. The training
 * loop is a placeholder: it plots at every epoch divisible by 1000 and draws a
 * random sample index, but no learning function is called yet.
 */
void test_mnist_nn(int generate_pdf) {
  int i, j, epoch;
  double mean;
  int training_epochs = 1; // <- increase this value when actually running a training experiment
  int train_N = MNIST_NUM_IMAGES;
  int n_visible = MNIST_SIZE;
  int n_hidden = 100;
  unsigned char **digit;
  double **train_X;
  double *img;
  int MARGIN_X = 2;
  int MARGIN_Y = 2;
  int N_XUNITS = 28; // MNIST images are 28x28
  int N_YUNITS = 28;
  char buf[100]; // for gnuplot
  // (2+28+2) * (2+(28+2)) * n_hidden * 2 doubles. This is presumably an image
  // canvas with 2-pixel margins for the plotting code (its use is not visible).
  img = alloc_1d_dbl( (MARGIN_X + N_XUNITS + MARGIN_X)*(MARGIN_Y + (N_YUNITS + MARGIN_Y))*n_hidden*2 );
  // training data
  digit = alloc_2d_uchar(MNIST_NUM_IMAGES, MNIST_SIZE);
  train_X = alloc_2d_dbl(MNIST_NUM_IMAGES, MNIST_SIZE);
  read_mnist_images(digit);
  // Mean pixel value over all training images.
  mean = 0.0;
  for(i=0; i < MNIST_NUM_IMAGES; i++){
    for(j=0; j < MNIST_SIZE; j++){
      mean += (double)digit[i][j];
    }
  }
  mean = mean/( (double)(MNIST_NUM_IMAGES*MNIST_SIZE) );
  // Centre every pixel on the global mean, then normalise each image.
  for(i=0; i < MNIST_NUM_IMAGES; i++){
    for(j=0; j < MNIST_SIZE; j++){
      train_X[i][j] = (double)digit[i][j] - mean;
    }
    normalize_input(train_X[i], MNIST_SIZE);
  }
  init_gnuplot (generate_pdf);
  // construct neural network model
  NN nn;
  NN__construct(&nn, train_N, n_visible, n_hidden, NULL, NULL, NULL);
  NN_init_weights_mnist(&nn, train_N, train_X);
  // train
  // Runs training_epochs + 1 iterations (epoch = 0..training_epochs inclusive).
  for(epoch=0; epoch<=training_epochs; epoch++) {
    if ( epoch % 1000 == 0){
      // sprintf (buf, "t=%7d", epoch);
      sprintf (buf, "MNIST data");
      plot3d_gnuplot (&nn, img, buf, n_visible, n_hidden, generate_pdf);
    }
    // Random sample index in [0, MNIST_NUM_IMAGES); currently unused.
    i = (int)((double)MNIST_NUM_IMAGES*drand48());
    // Call the learning function here to train (presumably on sample i).
  }
  // destruct NN
  NN__destruct(&nn);
  // Presumably each alloc_2d_* result is a row-pointer array whose rows share
  // one block owned by [0] -- confirm against irl_utility.
  free(digit[0]);
  free(digit);
  free(train_X[0]);
  free(train_X);
  free(img);
}

/*
 * Parse command-line options, seed drand48() and run test_mnist_nn().
 *
 *   -r <seed> : seed passed to srand48 (default RAND_SEED)
 *   -p <n>    : figure mode passed to test_mnist_nn (1:pdf 2:gif 3:eps)
 *   -w <n>    : "set random weight" flag
 * Any other option prints the usage text to stderr and exits with status 0.
 *
 * NOTE(review): options are matched only on their second character, with no
 * check for a leading '-'; an empty argument therefore reads past its
 * terminator. An option given as the last argument makes argv[++i] the NULL
 * argv[argc], which is then passed to atoi(). set_random_weight is used only
 * in the usage text; j and error are never used.
 */
int main (int argc, char *argv[]){
  int i,j;
  long seed = RAND_SEED;
  int *error = NULL;
  int set_random_weight = 1;
  int generate_figure = 0;
  for (i = 1; i < argc; i++) {
    switch (*(argv[i] + 1)) {
      case 'r':
        seed = atoi (argv[++i]);
        break;
      case 'p':
        generate_figure = atoi (argv[++i]);
        break;
      case 'w':
        set_random_weight = atoi (argv[++i]);
        break;
      default:
        fprintf (stderr, "Usage : %s\n", argv[0]);
        fprintf (stderr, "\t-r : random-seed(%ld)\n", seed);
        fprintf (stderr, "\t-p : generate figure(1:pdf 2:gif 3:eps): (%d)\n", generate_figure);
        fprintf (stderr, "\t-w : set random weight(%d)\n", set_random_weight);
        exit (0);
        break;
    }
  }
  srand48 (seed);
  test_mnist_nn(generate_figure);
  return 0;
}