Wu Yuxiong, Python Neural Networks: A TensorFlow Two-Hidden-Layer Autoencoder for the MNIST Handwritten Digit Dataset, Visualized with TensorBoard (2...)
Published: 2019-06-12


The complete program is below: it builds a fully connected autoencoder (784 → 256 → 128 → 256 → 784) with sigmoid activations, trains it on MNIST with RMSProp against a mean-squared-error reconstruction loss, and logs scalar, histogram, and image summaries for TensorBoard.

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

batch_size = 128        # mini-batch size
display_step = 1        # logging interval (in epochs)
learning_rate = 0.01    # learning rate
training_epochs = 20    # number of epochs; one epoch = n_samples / batch_size iterations
example_to_show = 10    # number of images to display

n_hidden1_units = 256   # first hidden layer
n_hidden2_units = 128   # second hidden layer
n_input_units = 784
n_output_units = n_input_units

# Attach mean/stddev/max/min scalar summaries and a histogram to a variable.
def variable_summaries(var):
    with tf.name_scope('summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)  # the mean is a single value, so log it as a scalar
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)  # note: this is a scalar
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

def WeightsVariable(n_in, n_out, name_str):
    return tf.Variable(tf.random_normal([n_in, n_out]), dtype=tf.float32, name=name_str)

def biasesVariable(n_out, name_str):
    return tf.Variable(tf.random_normal([n_out]), dtype=tf.float32, name=name_str)

# Encoder: 784 -> 256 -> 128
def encoder(x_origin, activate_func=tf.nn.sigmoid):
    with tf.name_scope('Layer1'):
        Weights = WeightsVariable(n_input_units, n_hidden1_units, 'Weights')
        biases = biasesVariable(n_hidden1_units, 'biases')
        x_code1 = activate_func(tf.add(tf.matmul(x_origin, Weights), biases))
        variable_summaries(Weights)
        variable_summaries(biases)
    with tf.name_scope('Layer2'):
        Weights = WeightsVariable(n_hidden1_units, n_hidden2_units, 'Weights')
        biases = biasesVariable(n_hidden2_units, 'biases')
        x_code2 = activate_func(tf.add(tf.matmul(x_code1, Weights), biases))
        variable_summaries(Weights)
        variable_summaries(biases)
    return x_code2

# Decoder: 128 -> 256 -> 784
def decode(x_code, activate_func=tf.nn.sigmoid):
    with tf.name_scope('Layer1'):
        Weights = WeightsVariable(n_hidden2_units, n_hidden1_units, 'Weights')
        biases = biasesVariable(n_hidden1_units, 'biases')
        x_decode1 = activate_func(tf.add(tf.matmul(x_code, Weights), biases))
        variable_summaries(Weights)
        variable_summaries(biases)
    with tf.name_scope('Layer2'):
        Weights = WeightsVariable(n_hidden1_units, n_output_units, 'Weights')
        biases = biasesVariable(n_output_units, 'biases')
        x_decode2 = activate_func(tf.add(tf.matmul(x_decode1, Weights), biases))
        variable_summaries(Weights)
        variable_summaries(biases)
    return x_decode2

with tf.Graph().as_default():
    with tf.name_scope('Input'):
        X_input = tf.placeholder(tf.float32, [None, n_input_units])
    with tf.name_scope('Encode'):
        X_code = encoder(X_input)
    with tf.name_scope('decode'):
        X_decode = decode(X_code)
    with tf.name_scope('loss'):
        # reconstruction loss: mean squared error between input and reconstruction
        loss = tf.reduce_mean(tf.pow(X_input - X_decode, 2))
    with tf.name_scope('train'):
        Optimizer = tf.train.RMSPropOptimizer(learning_rate)
        train = Optimizer.minimize(loss)

    # scalar summaries
    with tf.name_scope('LossSummary'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('learning_rate', learning_rate)

    # image summaries: original inputs vs. their reconstructions
    with tf.name_scope('ImageSummary'):
        image_original = tf.reshape(X_input, [-1, 28, 28, 1])
        image_reconstruction = tf.reshape(X_decode, [-1, 28, 28, 1])
        tf.summary.image('image_original', image_original, 9)
        tf.summary.image('image_reconstruction', image_reconstruction, 9)

    # merge all summaries into a single op
    merged_summary = tf.summary.merge_all()

    init = tf.global_variables_initializer()

    writer = tf.summary.FileWriter(logdir='E:\\tensorboard\\logsssxx', graph=tf.get_default_graph())
    writer.flush()

    mnist = input_data.read_data_sets('E:\\MNIST_data\\', one_hot=True)

    with tf.Session() as sess:
        sess.run(init)
        total_batch = int(mnist.train.num_examples / batch_size)
        for epoch in range(training_epochs):
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # one training step; fetch the loss in the same run instead of re-evaluating it
                _, Loss = sess.run([train, loss], feed_dict={X_input: batch_xs})
            if epoch % display_step == 0:
                print('Epoch: %04d' % (epoch + 1), 'loss= ', '{:.9f}'.format(Loss))
                summary_str = sess.run(merged_summary, feed_dict={X_input: batch_xs})
                writer.add_summary(summary_str, epoch)
                writer.flush()
        writer.close()
        print('Training finished!')
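The hyperparameter example_to_show = 10 is defined but never used in the program above. A minimal sketch of the side-by-side comparison it presumably intends, assuming matplotlib is installed and that the snippet runs inside the same "with tf.Session() as sess:" block after the training loop, might look like this:

# Hypothetical sketch (not in the original post): show example_to_show test
# digits next to their reconstructions. Assumes matplotlib is available and
# that this runs inside the Session block, after training completes.
import matplotlib.pyplot as plt

test_images = mnist.test.images[:example_to_show]
reconstructions = sess.run(X_decode, feed_dict={X_input: test_images})

fig, axes = plt.subplots(2, example_to_show, figsize=(example_to_show, 2))
for i in range(example_to_show):
    axes[0][i].imshow(test_images[i].reshape(28, 28), cmap='gray')      # original digit
    axes[1][i].imshow(reconstructions[i].reshape(28, 28), cmap='gray')  # reconstruction
    axes[0][i].axis('off')
    axes[1][i].axis('off')
plt.show()

To browse the logged graph, scalars, histograms, and image summaries, point TensorBoard at the log directory used by the FileWriter and open the reported URL in a browser:

tensorboard --logdir=E:\tensorboard\logsssxx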

Reposted from: https://www.cnblogs.com/tszr/p/10864351.html
