/*
 * Backpropagation demo by gnuplot 2006.5.18 A.Date
 *
 * Trains a small network (hidden units with weights w[][], one output unit
 * with weights s[]) and, after every trial, writes gnuplot commands to the
 * streams gp and gp2 so the data and the hidden units' lines are redrawn.
 *
 * NOTE(review): this copy of the file looks damaged. Text appears to be
 * missing wherever a '<' was directly followed by a name: the header names
 * after each #include, and two stretches inside main() (marked below).
 * Restore them from a clean copy before building.
 */

/* NOTE(review): header names are missing on the three #include lines below.
 * Presumably <stdio.h>, <stdlib.h> and <math.h> -- confirm.  usleep() is also
 * called in main(); no include for it is visible here. */
#include
#include /* drand48 */
#include /* log */

#define MAX_UNITS 100    /* capacity of the per-unit arrays declared in main() */
#define MAX_DATA 1000    /* capacity (rows) of the data[] array in main() */
#define N_DIM 3 /* (dimension of input) + 1 */

/* Try changing the parameter values below. */
#define RAND_SEED 12345678 /* random seed */
#define SLEEP 20000 /* display speed: argument passed to usleep() after each trial */

int n_hidden = 3; /* number of hidden-layer units */
int n_trials = 200; /* number of training iterations */
double eta = 0.1; /* learning rate */
double lambda=1.0; /* parameter (gain) of the sigmoid function */

char data_file1[] = "apples3.dat";   /* first data set; plotted with point style 1 */
char data_file2[] = "oranges3.dat";  /* second data set; plotted with point style 3 */
char error_file[] = "errors3.dat";   /* file plotted "with dots" on gp2 */

int n_data, n_data1, n_data2;

/* a random sample from standard normal distribution */
/* modified from "C Gengo ni yoru Saishin Algorithm Jiten"
 * (an algorithm dictionary for the C language), p.135 */
/*
 * Polar-method sampler: two uniform values r1, r2 in [-1, 1) are drawn until
 * s = r1*r1 + r2*r2 satisfies 0 < s <= 1, which yields two normal deviates
 * r1*s' and r2*s' with s' = sqrt(-2*log(s)/s).  The first is returned
 * immediately; the second is kept in static storage and returned by the
 * next call, so drand48() is only consulted on every other call.
 * The static state makes this function non-reentrant.
 */
double nrnd()
{
    static int sw=0;          /* 0: generate a new pair, 1: a saved value is pending */
    static double r1,r2,s;
    if (sw==0){
        sw=1;
        do {
            r1=2.0*drand48()-1.0;
            r2=2.0*drand48()-1.0;
            s=r1*r1+r2*r2;
        } while (s>1.0 || s==0.0);   /* reject points outside the unit disc and the origin */
        s=sqrt(-2.0*log(s)/s);
        return(r1*s);
    }
    else {
        sw=0;
        return(r2*s);                /* second deviate of the pair generated last time */
    }
}

/* Logistic function 1 / (1 + exp(-lambda*x)); lambda is the file-scope gain. */
double sigmoid(double x){
    return 1.0/(1.0 + exp(-lambda*x));
}

/*
 * Program entry point (old-style K&R parameter declarations).
 *
 * Only part of the body is visible in this copy; see the NOTE(review)
 * markers.  The visible parts accumulate the output-unit weight changes,
 * apply the accumulated changes to s[] and w[][], and emit gnuplot commands.
 */
int main (argc, argv)
    int argc;
    char *argv[];
{
    FILE *gp, *gp2, *fp, *fp2;   /* gp, gp2 receive gnuplot commands; fp2 receives "d1 d2" lines */
    int a,i,j,k,l,n,t;           /* t is printed as the trial number in the plot title */
    int seed = 1234523890; /* set your favorite number */
    int n_data;                  /* NOTE(review): shadows the file-scope n_data */
    int teacher, answer;         /* target label and thresholded network output (0 or 1) */
    int e3;                      /* sum of (teacher - answer)^2; shown as "error" in the title */
    double d1,d2;
    double dw[MAX_UNITS][N_DIM]; /* accumulated changes for w */
    double w[MAX_UNITS][N_DIM]; /* w[i][j], y_1,...y_i */
    double data[MAX_DATA][N_DIM];
    double u[MAX_UNITS], ui, v, z, e, e2, eta2;
    double r0, r[MAX_UNITS];
    double y[MAX_UNITS]; /* output of hidden units */
    double s[MAX_UNITS]; /* connection weights of output unit */
    double ds[MAX_UNITS];        /* accumulated changes for s */
    double x_start = -1.0, x_end=5.0;
    double y_start = -3.0, y_end=5.0;

    /* NOTE(review): text is missing between "i" and "0.5" on the next line.
     * Whatever initialised the weights, opened the streams, started the
     * trial/data loops and computed z is not visible in this copy. */
    for (i=1; i 0.5){
        answer = 1;
    }
    else{
        answer = 0;
    }

    /* answer is 0 or 1; if teacher is also 0/1 (its assignment is not visible
     * here) this counts misclassified samples. */
    e3 += (teacher - answer)*(teacher - answer);

    e = (double)teacher - z;     /* output error for this sample */
    e2 = e2 + e*e;               /* running sum of squared errors */

    /* Presumably d(e*e)/d(output unit's net input), assuming z is the sigmoid
     * of that input -- the computation of z is not visible; confirm. */
    r0 = -2.0*lambda*e*z*(1.0-z);

    /* Accumulate the change for every output weight, index 0 included. */
    for ( i=0; i<= n_hidden; i++) {
        ds[i] = ds[i] - eta*r0*y[i];
    }

    /* NOTE(review): text is missing between "i" and "0.5" on the next line
     * (presumably the hidden-weight accumulation into dw and the start of the
     * code that fills fp2 -- confirm against a clean copy). */
    for ( i=1; i 0.5 ){
        fprintf(fp2, "%.3lf \t %.3lf\n", d1, d2);
    }
    }   /* closes a block opened in the missing text */
    }   /* closes a block opened in the missing text */
    fflush(fp2);
    fclose(fp2);

    /* Second plot: both data sets plus the contents of error_file as dots;
     * the dot style number is 5 when e3 is zero and 72 otherwise. */
    fprintf(gp2, "plot '%s' with points 1, '%s' with points 3,", data_file1, data_file2);
    if ( e3==0 ){
        /* NOTE(review): in this copy the format string below is split by a
         * line break after "dots "; it is kept exactly as found. */
        fprintf(gp2, "'%s' with dots 
5\n",error_file);
    }
    else{
        fprintf(gp2, "'%s' with dots 72\n",error_file);
    }
    fflush(gp2);
    }   /* closes a block opened in the missing text */

    /* update weights */
    /* The accumulated changes are applied here in one step.  The three
     * explicit w[i][0..2] updates match N_DIM == 3. */
    s[0] = s[0] + ds[0];
    for ( i=1; i<=n_hidden; i++) {
        s[i] = s[i] + ds[i];
        w[i][0] = w[i][0] + dw[i][0];
        w[i][1] = w[i][1] + dw[i][1];
        w[i][2] = w[i][2] + dw[i][2];
    }

    /* plot at u=0 (hyper planes) */
    /* For hidden unit i the line drawn is y = (-w1/w2)*x + (-w0/w2), i.e. the
     * points where w0 + w1*x + w2*y = 0.  Line style 1 is used when w[i][2] is
     * positive, style 3 otherwise.
     * NOTE(review): there is no guard against w[i][2] == 0 in the divisions. */
    fprintf(gp, "set title 't =%d, error:%3d/%3d'\n", t,e3,n_data);
    fprintf(gp, "plot '%s' with points 1, '%s' with points 3", data_file1, data_file2);
    for ( i=1; i<=n_hidden; i++) {
        fprintf(gp, ",%.3lf*x + %.3lf with lines ", -w[i][1]/w[i][2], -w[i][0]/w[i][2]);
        if ( w[i][2] > 0 ){
            fprintf(gp, "1");
        }
        else {
            fprintf(gp, "3");
        }
    }
    fprintf(gp,"\n");   /* terminates the plot command */
    fflush(gp);
    usleep(SLEEP);      /* pause so each frame stays visible */
    } /* end of a trial */
    }   /* closes a block opened in the missing text */
}