Skip to content

Instantly share code, notes, and snippets.

@Jwata
Last active July 20, 2018 06:42
Show Gist options
  • Select an option

  • Save Jwata/0a264309824d2cd6b8fb97b202858604 to your computer and use it in GitHub Desktop.

Select an option

Save Jwata/0a264309824d2cd6b8fb97b202858604 to your computer and use it in GitHub Desktop.
import os

import tensorflow as tf
# model1
def model1_fn(inputs, labels, learning_rate=0.001):
    """Build a single-unit dense regressor trained with gradient descent.

    Args:
        inputs: input tensor passed to ``tf.layers.dense`` (1 output unit).
        labels: target tensor for the mean-squared-error loss.
        learning_rate: step size for ``tf.train.GradientDescentOptimizer``.

    Returns:
        ``(train_op, cost, preds)``: the op returned by ``minimize`` (run it
        once per training step), the MSE loss tensor, and the predictions.
    """
    predictions = tf.layers.dense(inputs, 1)
    loss = tf.losses.mean_squared_error(labels, predictions)
    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = sgd.minimize(loss)
    return train_op, loss, predictions
# model2
def model2_fn(inputs, labels, learning_rate=0.001):
    """Build the second single-unit dense regressor (MSE loss, plain SGD).

    Args:
        inputs: input tensor passed to ``tf.layers.dense`` (1 output unit).
        labels: target tensor for the mean-squared-error loss.
        learning_rate: step size for ``tf.train.GradientDescentOptimizer``.

    Returns:
        ``(train_op, cost, preds)``: the op returned by ``minimize``, the
        MSE loss tensor, and the prediction tensor.
    """
    out = tf.layers.dense(inputs, 1)
    mse = tf.losses.mean_squared_error(labels, out)
    step_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(mse)
    return step_op, mse, out
# Graph for model1. The Saver is constructed *before* model2 is built, so its
# default var_list (the saveable variables that exist at construction time)
# covers only model1's variables: the checkpoint holds model1 alone.
inputs1 = tf.placeholder(shape=(None, 1), dtype=tf.float32)
labels1 = tf.placeholder(shape=(None, 1), dtype=tf.float32)
op1, cost1, preds1 = model1_fn(inputs1, labels1)
saver = tf.train.Saver()

# Graph for model2, built in the same default graph as model1.
inputs2 = tf.placeholder(shape=(None, 1), dtype=tf.float32)
labels2 = tf.placeholder(shape=(None, 1), dtype=tf.float32)
op2, cost2, preds2 = model2_fn(inputs2, labels2)

# train model1
x1 = [[1], [2]]
y1 = [[3], [5]]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1000):
        _, cost_, preds_ = sess.run(
            [op1, cost1, preds1], feed_dict={
                inputs1: x1,
                labels1: y1
            })
    print('Cost:', cost_)
    print('Preds:', preds_)
    # Saver.save fails when the checkpoint's parent directory does not exist,
    # so create it up front (no-op if it is already there).
    os.makedirs('models', exist_ok=True)
    saver.save(sess, 'models/test_model1')
class Predictor:
    """Run model1 predictions in a dedicated session restored from a checkpoint.

    The session created here is independent of any training session, so
    variable initialization performed in another session does not change the
    values restored here. Call ``close()`` (or use the instance as a context
    manager) to release the session.
    """

    def __init__(self, checkpoint_path='models/test_model1'):
        """Open a new session and restore the saved model1 variables into it.

        Args:
            checkpoint_path: checkpoint prefix previously written with
                ``saver.save``.

        Raises:
            Whatever ``saver.restore`` raises if the checkpoint cannot be
            loaded; the session is closed before the error propagates.
        """
        self.sess = tf.Session()
        try:
            saver.restore(self.sess, checkpoint_path)
        except Exception:
            # Don't leak the freshly opened session on a failed restore.
            self.sess.close()
            raise

    def predict(self, inputs):
        """Return model1's predictions for ``inputs`` (fed into ``inputs1``)."""
        return self.sess.run(preds1, feed_dict={inputs1: inputs})

    def close(self):
        """Close the underlying session; ``predict`` cannot be used afterwards."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
predictor = Predictor()
preds = predictor.predict([[3], [7]])
print(preds)  # example run: [[ 7.3167486], [16.239666 ]]

# model1's predictions become model2's inputs. They are only roughly
# [[7], [15]]; the labels below apply y = 2x + 1 to those rounded values,
# not to the exact predictions.
x2 = preds
y2 = [[7*2.+1.], [15*2.+1.]]

# train model2
with tf.Session() as sess:
    # Initializes every global variable (model1's included), but only in
    # this session; the predictor's session keeps its restored values.
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        _, cost2_, preds2_ = sess.run(
            [op2, cost2, preds2], feed_dict={
                inputs2: x2,
                labels2: y2
            })
    print('Cost:', cost2_)
    print('Preds:', preds2_)

# Same output as before model2 was trained: the predictor's variables were
# not re-initialized by the other session.
preds = predictor.predict([[3], [7]])
print(preds)  # example run: [[ 7.3167486], [16.239666 ]]
predictor.sess.close()  # release the predictor's session
@Jwata
Copy link
Author

Jwata commented Jul 20, 2018

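The predictor's parameters are not reset by `sess.run(tf.global_variables_initializer())` when it runs in a different session. Each `tf.Session` holds its own copy of the variable values, so initializing variables in the model2 training session leaves the values restored into the predictor's session untouched.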

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment