Last active
October 2, 2017 16:36
a Flux.jl example with a toy neural network that learns and, or, and xor
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Multi-task network in Julia with Flux\n", | |
| "\n", | |
| "This notebook illustrates how easy it is to build and train neural network models in Julia with [Flux.jl](https://github.com/FluxML/Flux.jl). Flux is a framework that gives you both a convenient, high-level API and the ability to mess with things at as low a level as you want.\n", | |
| "\n", | |
| "## Task and training data\n", | |
| "\n", | |
| "This network learns to do three different \"tasks\": binary and, or, and xor. The network gets two binary inputs plus a \"task\" signal, which is a one-hot encoding of the task variable. There's one hidden layer, and additionally the task inputs connect directly to the output as well." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "12-element Array{Int64,1}:\n", | |
| " 0\n", | |
| " 0\n", | |
| " 0\n", | |
| " 1\n", | |
| " 0\n", | |
| " 1\n", | |
| " 1\n", | |
| " 1\n", | |
| " 0\n", | |
| " 1\n", | |
| " 1\n", | |
| " 0" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "using Flux\n", | |
| "\n", | |
| "# inputs for each task (columns are observations)\n", | |
| "x1 = [0 0 1 1\n", | |
| " 0 1 0 1]\n", | |
| "\n", | |
| "funcs = [&, |, xor]\n", | |
| "\n", | |
| "y = vcat((f.(x1[1,:], x1[2,:]) for f in funcs)...)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Here I'm using the Julia broadcasting syntax `f.()` to apply each function `f` to the pairs of inputs (`x1[1,:]` is the first row of `x1`).\n", | |
| "\n", | |
| "We can use the `onehot` and `onehotbatch` functions to convert a vector of class labels (in this case, the task functions) into a boolean matrix using one-hot encoding:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "3×12 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:\n", | |
| " true true true true false … false false false false false\n", | |
| " false false false false true true false false false false\n", | |
| " false false false false false false true true true true" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x_task = Flux.onehotbatch(repeat(funcs, inner=size(x1, 2)), funcs)" | |
| ] | |
| }, | |
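| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the encoding concrete, here is a minimal hand-rolled sketch of the same idea in plain Julia. It is an illustration under my own assumptions, not how Flux implements `onehotbatch`: `classes`, `labels`, and `manual_onehot` are made-up names, and symbols stand in for the task functions. Each column has a single `true`, in the row of that observation's class." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# hand-rolled one-hot encoding (illustrative only, not Flux's implementation)\n", | |
| "classes = [:and, :or, :xor]\n", | |
| "labels = repeat(classes, inner=4)   # four observations per task\n", | |
| "\n", | |
| "# rows index classes, columns index observations\n", | |
| "manual_onehot = [label == cls for cls in classes, label in labels]" | |
| ] | |
| }, | |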
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We'll put the task inputs first so it'll be easy to pick them out later:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "5×12 Array{Int64,2}:\n", | |
| " 1 1 1 1 0 0 0 0 0 0 0 0\n", | |
| " 0 0 0 0 1 1 1 1 0 0 0 0\n", | |
| " 0 0 0 0 0 0 0 0 1 1 1 1\n", | |
| " 0 0 1 1 0 0 1 1 0 0 1 1\n", | |
| " 0 1 0 1 0 1 0 1 0 1 0 1" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x = vcat(x_task, repmat(x1, 1, 3))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Specifying the network\n", | |
| "\n", | |
| "First, I'll show how to specify the network with as much of Flux's high-level API as possible. It's not just a bunch of stacked layers, so we have to do a little bit of manual fiddling (which goes to show how flexible this abstraction is!).\n", | |
| "\n", | |
| "In Flux, a network is trained by specifying a loss function whose parameters are Flux's special \"tracked\" arrays, which support automatic differentiation and training. In practice, we can usually specify a bunch of Flux layers, which are just functions that take input and generate output, with tracked arrays as their parameters. We can do it manually, too, which we'll see below.\n", | |
| "\n", | |
| "To start, we'll specify the two pathways. The pathway through the hidden layer has two layers of connections (both densely connected), with a sigmoid nonlinearity after the first layer. (There will also be an output nonlinearity, but we need to add the outputs of the two pathways together before applying it.)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "loss (generic function with 1 method)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "using Flux: Dense\n", | |
| "\n", | |
| "n_hidden = 10\n", | |
| "hidden = Chain(\n", | |
| " Dense(size(x, 1), n_hidden, σ),\n", | |
| " Dense(n_hidden, 1)\n", | |
| ")\n", | |
| "\n", | |
| "## direct pathway from task units to output\n", | |
| "direct = Dense(size(x_task, 1), 1)\n", | |
| "\n", | |
| "## overall output: add hidden and direct output and apply sigmoid (all elementwise)\n", | |
| "m(x::Matrix) = σ.(hidden(x) .+ direct(x[1:size(x_task,1),:]))\n", | |
| "\n", | |
| "# define the loss function we'll optimize\n", | |
| "loss(x,y) = Flux.mse(m(x), y')" | |
| ] | |
| }, | |
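| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For reference, mean squared error can be written out by hand. The sketch below assumes `Flux.mse` averages the squared differences over all elements; `manual_mse` is a hypothetical helper for illustration, not part of Flux." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# hypothetical helper: mean of the elementwise squared differences\n", | |
| "manual_mse(yhat, y) = sum((yhat .- y).^2) / length(y)\n", | |
| "\n", | |
| "# predictions of 0.5 against targets 0 and 1 give (0.25 + 0.25) / 2 = 0.25\n", | |
| "manual_mse([0.5 0.5], [0 1])" | |
| ] | |
| }, | |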
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The initial loss before training, based on the randomly initialized weights:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 0-dimensional Array{Float64,0}:\n", | |
| "0.249905" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "loss(x,y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Notice that this is a special \"tracked\" array. This is how Flux supports automatic backprop.\n", | |
| "\n", | |
| "## Training the network\n", | |
| "\n", | |
| "Flux provides a number of conveniences for training models as well. An _optimizer_ is created from a list of parameters and returns a function that, when called, updates those parameters using the gradients computed by backprop. (We can also do this manually, which we'll see below.)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(::#58) (generic function with 1 method)" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ps = vcat(params(hidden), params(direct))\n", | |
| "η = 1\n", | |
| "opt = SGD(ps, η)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`opt` is a function that will update the parameters we've extracted from the two pathways. The `Flux.train!` function is a wrapper for the loop of calculating the loss, backpropagating, and calling the optimizer (slightly annoyingly, the data needs to be a vector of `(x, y)` tuples, even if `x` is a matrix and `y` a vector of targets):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss(x, y) = param(0.246506)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "dataset = (x,y)\n", | |
| "Flux.train!(loss, [dataset], opt, cb = () -> @show(loss(x,y)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "`train!` runs through all the data we pass it once. We can put it in a loop to train until the loss falls below a tolerance:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss(x, y) = param(0.243395)\n", | |
| "loss(x, y) = param(1.61531e-5)\n", | |
| " 15.496159 seconds (71.65 M allocations: 5.114 GiB, 4.17% gc time)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "using Base.Iterators: repeated\n", | |
| "\n", | |
| "tol = 1e-5\n", | |
| "stop() = loss(x,y).data[] < tol\n", | |
| "\n", | |
| "# show the value of the loss function every 10 seconds\n", | |
| "callback = Flux.throttle(() -> @show(loss(x,y)), 10)\n", | |
| "\n", | |
| "dataset = repeated((x,y), 500)\n", | |
| "\n", | |
| "@time begin\n", | |
| " while !stop()\n", | |
| " ## currently there's a performance problem in the train! function that makes\n", | |
| " ## gc time dominate for small models like this, so we'll run the train loop manually\n", | |
| " #Flux.train!(loss, dataset, opt, cb = Flux.throttle(callback, 10))\n", | |
| " for d in dataset\n", | |
| " Flux.back!(loss(d...))\n", | |
| " opt()\n", | |
| " callback()\n", | |
| " end\n", | |
| " end\n", | |
| "end" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 0-dimensional Array{Float64,0}:\n", | |
| "9.98587e-6" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "loss(x,y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "12×2 Array{Float64,2}:\n", | |
| " 0.000630202 0.0\n", | |
| " 0.00138896 0.0\n", | |
| " 0.00138876 0.0\n", | |
| " 0.997339 1.0\n", | |
| " 0.00297521 0.0\n", | |
| " 0.998626 1.0\n", | |
| " 0.998627 1.0\n", | |
| " 0.997396 1.0\n", | |
| " 0.00236595 0.0\n", | |
| " 0.995386 1.0\n", | |
| " 0.995389 1.0\n", | |
| " 0.00639834 0.0" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hcat(m(x).data', y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Low-level interface\n", | |
| "\n", | |
| "We can also specify the network using a low-level interface and still benefit from Flux's abstractions (tracked arrays, automatic backprop)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 10-element Array{Float64,1}:\n", | |
| " 2.44412 \n", | |
| " 0.606339\n", | |
| " -3.49427 \n", | |
| " -0.467204\n", | |
| " 1.7908 \n", | |
| " -2.58228 \n", | |
| " -0.329082\n", | |
| " -2.60199 \n", | |
| " 0.448517\n", | |
| " -3.82918 " | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hidden[1].W\n", | |
| "hidden[1].b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 0-dimensional Array{Float64,0}:\n", | |
| "3.15446" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "using Flux.Tracker: param, back!, data, grad\n", | |
| "\n", | |
| "W_hidden = param(randn(n_hidden, size(x,1)))\n", | |
| "b_hidden = param(randn(n_hidden))\n", | |
| "hidden2(x) = σ.(W_hidden*x .+ b_hidden)\n", | |
| "\n", | |
| "W_out = param(randn(1, n_hidden))\n", | |
| "b_out = param(randn(1))\n", | |
| "W_direct = param(randn(1, size(x_task, 1)))\n", | |
| "predict(x::Matrix) = σ.(W_out*hidden2(x) .+ W_direct*x[1:3,:] .+ b_out)\n", | |
| "\n", | |
| "loss2(x,y) = sum((predict(x) .- y').^2)\n", | |
| "loss2(x,y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Note how the value of `loss2` is also `Tracked`. Calling `back!` on the tracked value will calculate the gradient of the loss with respect to each of the parameters:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 0-dimensional Array{Float64,0}:\n", | |
| "3.15446" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "l = loss2(x,y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "10×5 Array{Float64,2}:\n", | |
| " -0.00548446 -0.0199413 0.0175428 -0.0195804 -0.0199099 \n", | |
| " 0.0214485 0.183165 -0.132873 0.122418 0.152889 \n", | |
| " -0.00695408 -0.0751547 0.0498836 -0.0555422 -0.0568866 \n", | |
| " 0.0316629 0.0468998 -0.0684927 0.0805919 0.116873 \n", | |
| " -0.00050766 0.000505941 -0.00204762 0.00315389 0.00371962\n", | |
| " 0.00411865 0.0195692 -0.0139061 0.0203315 0.0170381 \n", | |
| " 0.00433668 0.0209703 -0.014808 0.0208767 0.0201817 \n", | |
| " -0.00399932 0.0717888 -0.0267617 0.0879622 0.058855 \n", | |
| " -0.0600394 -0.197018 0.0356492 -0.218219 -0.266648 \n", | |
| " -0.00335029 -0.123585 0.119496 -0.0332793 -0.0536155 " | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "back!(loss2(x,y))\n", | |
| "grad(W_hidden)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "...and we can update `W_hidden` (using the `.data` field to access the underlying values in place) according to the gradient:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Tracked 0-dimensional Array{Float64,0}:\n", | |
| "3.11955" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "W_hidden.data .-= 0.1*grad(W_hidden)\n", | |
| "\n", | |
| "loss2(x,y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We can define our own update function (which is essentially what `Flux.SGD` does) by looping over all the parameters and updating their `.data` based on the gradient of the parameters. Then, update until the loss is below tolerance (as above):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss2(x, y) = param(2.82972)\n", | |
| "loss2(x, y) = param(0.000383281)\n", | |
| "loss2(x, y) = param(0.000171005)\n", | |
| "loss2(x, y) = param(0.000109063)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
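| "# take one gradient-descent step on each parameter, then zero its gradient.\n", | |
| "# the caller runs back! on the loss first; calling it again here would add\n", | |
| "# a second (full-batch) gradient on top of the one already stored\n", | |
| "function update!(ps, η)\n", | |
| "    for p in ps\n", | |
| "        ∇, dat = grad(p), data(p)\n", | |
| "        dat .-= η .* ∇\n", | |
| "        ∇ .= 0\n", | |
| "    end\n", | |
| "end\n", | |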
| "\n", | |
| "cb = Flux.throttle(() -> @show(loss2(x,y)), 5)\n", | |
| "\n", | |
| "ps = [W_hidden, b_hidden, W_out, b_out, W_direct]\n", | |
| "tol = 1e-4\n", | |
| "while loss2(x,y).data[] > tol\n", | |
| " back!(loss2(x,y))\n", | |
| " update!(ps, 0.1)\n", | |
| " cb()\n", | |
| "end" | |
| ] | |
| }, | |
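| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The update rule itself doesn't depend on Flux. As a minimal sketch, here is the same step, `w -= η * gradient`, applied to a one-parameter problem with a hand-written derivative. The names `f`, `df`, and `descend` are made up for this illustration." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# minimize f(w) = (w - 3)^2 using its analytic derivative\n", | |
| "f(w) = (w - 3)^2\n", | |
| "df(w) = 2 * (w - 3)\n", | |
| "\n", | |
| "function descend(w, η, n_steps)\n", | |
| "    for i in 1:n_steps\n", | |
| "        w -= η * df(w)   # same rule as update!, for a single scalar\n", | |
| "    end\n", | |
| "    return w\n", | |
| "end\n", | |
| "\n", | |
| "descend(0.0, 0.1, 100)   # ends up very close to the minimum at w = 3" | |
| ] | |
| }, | |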
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "12×2 Array{Float64,2}:\n", | |
| " 0.0 0.0\n", | |
| " 0.0 0.0\n", | |
| " 0.0 0.0\n", | |
| " 1.0 1.0\n", | |
| " 0.0 0.0\n", | |
| " 1.0 1.0\n", | |
| " 1.0 1.0\n", | |
| " 1.0 1.0\n", | |
| " 0.0 0.0\n", | |
| " 1.0 1.0\n", | |
| " 1.0 1.0\n", | |
| " 0.01 0.0" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hcat(round.(predict(x).data', 2), y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Performance\n", | |
| "\n", | |
| "To compare to the Matlab implementation, we need to do a few extra things. That implementation shuffles the input data on every iteration and does one backprop step per observation (instead of in batch), and it saves the MSE throughout training. Also, there are no bias weights learned in the Matlab version, but that's not going to speed things up here." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "12-element Array{Tuple{Array{Float64,2},Float64},1}:\n", | |
| " ([1.0; 0.0; … ; 0.0; 0.0], 0.0)\n", | |
| " ([1.0; 0.0; … ; 0.0; 1.0], 0.0)\n", | |
| " ([1.0; 0.0; … ; 1.0; 0.0], 0.0)\n", | |
| " ([1.0; 0.0; … ; 1.0; 1.0], 1.0)\n", | |
| " ([0.0; 1.0; … ; 0.0; 0.0], 0.0)\n", | |
| " ([0.0; 1.0; … ; 0.0; 1.0], 1.0)\n", | |
| " ([0.0; 1.0; … ; 1.0; 0.0], 1.0)\n", | |
| " ([0.0; 1.0; … ; 1.0; 1.0], 1.0)\n", | |
| " ([0.0; 0.0; … ; 0.0; 0.0], 0.0)\n", | |
| " ([0.0; 0.0; … ; 0.0; 1.0], 1.0)\n", | |
| " ([0.0; 0.0; … ; 1.0; 0.0], 1.0)\n", | |
| " ([0.0; 0.0; … ; 1.0; 1.0], 0.0)" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "## make a vector of single-observation x,y tuples (with x as a matrix so we don't need new methods)\n", | |
| "xys = [(float(x[:,i:i]), float(y[i])) for i in 1:length(y)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "12-element Array{TrackedArray{…,Array{Float64,0}},1}:\n", | |
| " param(7.2204e-7) \n", | |
| " param(1.89381e-6)\n", | |
| " param(4.1702e-6) \n", | |
| " param(1.28171e-5)\n", | |
| " param(3.99636e-6)\n", | |
| " param(2.52668e-9)\n", | |
| " param(1.21483e-8)\n", | |
| " param(6.62773e-6)\n", | |
| " param(5.34491e-6)\n", | |
| " param(1.47294e-5)\n", | |
| " param(1.62699e-5)\n", | |
| " param(3.34133e-5)" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "[loss2(xy...) for xy in xys]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss2(x, y) = param(0.000400447)\n", | |
| " 3.513151 seconds (14.27 M allocations: 957.211 MiB, 3.84% gc time)\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2000-element Array{Float64,1}:\n", | |
| " 2.06891 \n", | |
| " 1.83597 \n", | |
| " 1.97047 \n", | |
| " 1.81257 \n", | |
| " 1.34066 \n", | |
| " 1.29541 \n", | |
| " 1.11231 \n", | |
| " 1.05008 \n", | |
| " 0.954588 \n", | |
| " 0.939797 \n", | |
| " 0.897048 \n", | |
| " 0.726046 \n", | |
| " 0.671461 \n", | |
| " ⋮ \n", | |
| " 0.000231743\n", | |
| " 0.000231601\n", | |
| " 0.000231493\n", | |
| " 0.000231337\n", | |
| " 0.00023122 \n", | |
| " 0.000231087\n", | |
| " 0.000230968\n", | |
| " 0.000230809\n", | |
| " 0.000230671\n", | |
| " 0.000230542\n", | |
| " 0.000230429\n", | |
| " 0.000230311" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "reset!(x) = randn!(x.data)\n", | |
| "reset!.(ps)\n", | |
| "\n", | |
| "n_training = 2000\n", | |
| "η = 0.3\n", | |
| "\n", | |
| "function train!(ps, xys, η, n_training, cb)\n", | |
| "\n", | |
| " mses = zeros(n_training)\n", | |
| "\n", | |
| " for iter in 1:n_training\n", | |
| " shuffle!(xys)\n", | |
| " for (x,y) in xys\n", | |
| " l = loss2(x,y)\n", | |
| " mses[iter] += l.data[]\n", | |
| " back!(l)\n", | |
| " update!(ps, η)\n", | |
| " end\n", | |
| " cb()\n", | |
| " end\n", | |
| "\n", | |
| " return mses\n", | |
| "end\n", | |
| "\n", | |
| "@time train!(ps, xys, η, n_training, cb)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "I'm getting 3.2 seconds for 2,000 training iterations, vs. about 2.6 for the Matlab code. Matlab does surprisingly well here! The tradeoff is that this code is more concise and more expressive, without sacrificing detailed control. I'd have to dig in a little to see whether it can be optimized at all." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Profiling" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "loss2(x, y) = param(2.31568)\n", | |
| "1 ./event.jl:436; (::Base.##300#301{IJulia.#send_st...\n", | |
| " 1 ...0.6/IJulia/src/stdio.jl:88; send_stdio(::String)\n", | |
| " 1 ....6/IJulia/src/stdio.jl:130; send_stream(::String)\n", | |
| " 1 ...v0.6/IJulia/src/msg.jl:48; send_ipython(::ZMQ.Socket, ::IJul...\n", | |
| "3630 ./task.jl:335; (::IJulia.##14#17)()\n", | |
| " 3630 ...Julia/src/eventloop.jl:8; eventloop(::ZMQ.Socket)\n", | |
| " 3630 ...rc/execute_request.jl:154; execute_request(::ZMQ.Socket, ::...\n", | |
| " 3630 ...Compat/src/Compat.jl:464; include_string(::Module, ::Stri...\n", | |
| " 3630 ./<missing>:?; anonymous\n", | |
| " 3630 ./profile.jl:23; macro expansion\n", | |
| " 1 ./In[20]:13; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 1258 ./In[20]:14; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 1249 ./In[12]:12; loss2\n", | |
| " 929 ./In[12]:10; predict(::Array{Float64,2})\n", | |
| " 8 ./In[12]:5; hidden2\n", | |
| " 5 .../src/tracker/lib.jl:63; *(::TrackedArray{…,Array{Fl...\n", | |
| " 5 .../tracker/Tracker.jl:36; Type\n", | |
| " 1 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
| " 4 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
| " 1 ./linalg/matmul.jl:146; *\n", | |
| " 1 ./linalg/matmul.jl:341; gemm_wrapper!(::Array{Float...\n", | |
| " 2 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
| " 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
| " 2 ./abstractarray.jl:882; getindex\n", | |
| " 2 ./multidimensional.jl:438; _getindex\n", | |
| " 2 ./multidimensional.jl:442; macro expansion\n", | |
| " 1 ./multidimensional.jl:453; _unsafe_getindex(::IndexLin...\n", | |
| " 1 ...ltidimensional.jl:460; macro expansion\n", | |
| " 1 ...ltidimensional.jl:466; _unsafe_getindex!\n", | |
| " 1 ...ltidimensional.jl:472; macro expansion\n", | |
| " 1 ./cartesian.jl:64; macro expansion\n", | |
| " 1 ...tidimensional.jl:474; macro expansion\n", | |
| " 8 ./broadcast.jl:0; broadcast(::Function, ::Trac...\n", | |
| " 850 ./broadcast.jl:434; broadcast(::Function, ::Trac...\n", | |
| " 8 ./broadcast.jl:0; containertype(::TrackedArra...\n", | |
| " 2 ./broadcast.jl:34; containertype(::TrackedArra...\n", | |
| " 6 ...src/tracker/lib.jl:0; broadcast_c(::Function, ::T...\n", | |
| " 314 ...src/tracker/lib.jl:129; broadcast_c(::Function, ::T...\n", | |
| " 16 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
| " 16 ./tuple.jl:181; map\n", | |
| " 7 ./array.jl:455; _collect(::Array{Float64,2}...\n", | |
| " 6 ./array.jl:375; _similar_for(::Array{Float...\n", | |
| " 2 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
| " 1 ./tuple.jl:0; map(::Flux.Tracker.##8#10{3...\n", | |
| " 6 ./tuple.jl:178; map(::Flux.Tracker.##8#10{3...\n", | |
| " 4 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
| " 3 ./array.jl:375; _similar_for(::Array{Float...\n", | |
| " 1 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
| " 1 ...rc/tracker/lib.jl:97; #8\n", | |
| " 240 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
| " 236 ./broadcast.jl:434; broadcast\n", | |
| " 166 ./broadcast.jl:310; broadcast_c\n", | |
| " 3 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
| " 18 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
| " 1 ./broadcast.jl:312; broadcast_c\n", | |
| " 69 ./broadcast.jl:314; broadcast_c\n", | |
| " 9 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
| " 32 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
| " 2 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
| " 8 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
| " 1 ./broadcast.jl:0; _broadcast!(::##11#12, ::A...\n", | |
| " 4 ./broadcast.jl:139; _broadcast!(::##11#12, ::A...\n", | |
| " 4 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:68; macro expansion\n", | |
| " 3 ./simdloop.jl:73; macro expansion\n", | |
| " 3 ./broadcast.jl:153; macro expansion\n", | |
| " 3 ./<missing>:0; (::##11#12)(::ForwardDif...\n", | |
| " 3 ...c/activation.jl:1; σ(::ForwardDiff.Dual{Vo...\n", | |
| " 2 ...ff/src/dual.jl:169; -\n", | |
| " 1 ...ff/src/dual.jl:170; exp\n", | |
| " 1 ...rc/partials.jl:82; *\n", | |
| " 1 ...c/partials.jl:109; *\n", | |
| " 1 ...c/partials.jl:199; scale_tuple\n", | |
| " 1 ...c/partials.jl:155; macro expansion\n", | |
| " 45 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
| " 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(:...\n", | |
| " 11 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted...\n", | |
| " 11 ./array.jl:455; _collect(::Array{ForwardDi...\n", | |
| " 6 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
| " 4 ...src/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
| " 41 ...src/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
| " 41 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
| " 39 ./array.jl:455; _collect(::Array{Float64,1}...\n", | |
| " 4 ./array.jl:375; _similar_for(::Array{Float6...\n", | |
| " 2 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
| " 1 ./array.jl:477; collect_to!(::Array{Forward...\n", | |
| " 193 ...src/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
| " 191 ./broadcast.jl:434; broadcast\n", | |
| " 128 ./broadcast.jl:310; broadcast_c\n", | |
| " 1 ./reflection.jl:0; _methods_by_ftype(::Any, :...\n", | |
| " 1 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
| " 15 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
| " 1 ./broadcast.jl:311; broadcast_c\n", | |
| " 1 ./broadcast.jl:53; broadcast_indices\n", | |
| " 1 ./broadcast.jl:48; broadcast_indices\n", | |
| " 1 ./broadcast.jl:52; broadcast_indices\n", | |
| " 1 ./abstractarray.jl:64; indices\n", | |
| " 62 ./broadcast.jl:314; broadcast_c\n", | |
| " 11 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
| " 30 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
| " 3 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
| " 8 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
| " 7 ./broadcast.jl:139; _broadcast!(::##9#10, ::Ar...\n", | |
| " 7 ./broadcast.jl:147; macro expansion\n", | |
| " 7 ./simdloop.jl:73; macro expansion\n", | |
| " 7 ./broadcast.jl:153; macro expansion\n", | |
| " 7 ./<missing>:0; (::##9#10)(::ForwardDiff....\n", | |
| " 4 ...rc/activation.jl:1; σ(::ForwardDiff.Dual{Vo...\n", | |
| " 4 ...iff/src/dual.jl:169; exp\n", | |
| " 1 ./special/exp.jl:68; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:111; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:112; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:131; exp(::Float64)\n", | |
| " 46 ...src/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
| " 2 .../tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::F...\n", | |
| " 5 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
| " 1 ./array.jl:451; _collect(::Array{ForwardDif...\n", | |
| " 4 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
| " 2 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
| " 1 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
| " 9 ...src/tracker/lib.jl:62; *(::TrackedArray{…,Array{F...\n", | |
| " 8 .../tracker/Tracker.jl:36; Type\n", | |
| " 8 .../tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*,...\n", | |
| " 2 ./linalg/matmul.jl:146; *\n", | |
| " 1 ./linalg/matmul.jl:348; gemm_wrapper!(::Array{Float...\n", | |
| " 5 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
| " 5 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
| " 5 ...src/tracker/lib.jl:63; *(::TrackedArray{…,Array{F...\n", | |
| " 4 .../tracker/Tracker.jl:36; Type\n", | |
| " 4 .../tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*,...\n", | |
| " 1 ./linalg/matmul.jl:146; *\n", | |
| " 3 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
| " 1 ./linalg/blas.jl:0; gemm!(::Char, ::Char, ::Fl...\n", | |
| " 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
| " 2 ./broadcast.jl:0; broadcast(::Function, ::Track...\n", | |
| " 295 ./broadcast.jl:434; broadcast(::Function, ::Track...\n", | |
| " 15 ...src/tracker/lib.jl:0; tracked_broadcast(::Function...\n", | |
| " 10 ...src/tracker/lib.jl:97; tracked_broadcast(::Function...\n", | |
| " 10 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2}...\n", | |
| " 1 ./array.jl:454; _collect(::Array{Float64,2},...\n", | |
| " 1 ./generator.jl:44; next\n", | |
| " 9 ./array.jl:455; _collect(::Array{Float64,2},...\n", | |
| " 9 ./array.jl:375; _similar_for(::Array{Float6...\n", | |
| " 2 ./abstractarray.jl:524; similar(::Array{Float64,2},...\n", | |
| " 191 ...src/tracker/lib.jl:100; tracked_broadcast(::Function...\n", | |
| " 186 ./broadcast.jl:434; broadcast\n", | |
| " 128 ./broadcast.jl:310; broadcast_c\n", | |
| " 1 ./reflection.jl:510; _methods_by_ftype(::Any, ::...\n", | |
| " 3 ./reflection.jl:512; _methods_by_ftype(::Any, ::...\n", | |
| " 16 ./reflection.jl:521; _methods_by_ftype(::Any, ::...\n", | |
| " 58 ./broadcast.jl:314; broadcast_c\n", | |
| " 8 ./broadcast.jl:0; broadcast_t(::Function, ::T...\n", | |
| " 23 ./broadcast.jl:266; broadcast_t(::Function, ::T...\n", | |
| " 1 ./boot.jl:317; Array{ForwardDiff.Dual{Void...\n", | |
| " 2 ./broadcast.jl:268; broadcast_t(::Function, ::T...\n", | |
| " 1 ...src/tracker/lib.jl:88; Flux.Tracker.Broadcasted(::...\n", | |
| " 71 ...src/tracker/lib.jl:101; tracked_broadcast(::Function...\n", | |
| " 1 .../tracker/Tracker.jl:0; Flux.Tracker.TrackedArray(::F...\n", | |
| " 1 .../tracker/Tracker.jl:11; Flux.Tracker.Call{Flux.Tracke...\n", | |
| " 26 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
| " 26 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
| " 26 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
| " 23 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
| " 8 .../src/tracker/lib.jl:52; sum(::TrackedArray{…,Array{...\n", | |
| " 4 ...x/src/tracker/lib.jl:4; toarray\n", | |
| " 25 ./In[20]:15; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 1 ./abstractarray.jl:0; getindex(::Array{Float64,0})\n", | |
| " 1 ./abstractarray.jl:882; getindex(::Array{Float64,0})\n", | |
| " 388 ./In[20]:16; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 386 ...src/tracker/back.jl:43; back!(::TrackedArray{…,Array...\n", | |
| " 45 ...src/tracker/back.jl:39; back!\n", | |
| " 45 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array{...\n", | |
| " 42 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sca...\n", | |
| " 42 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array...\n", | |
| " 35 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sc...\n", | |
| " 29 ...rc/tracker/back.jl:8; scan(::TrackedArray{…,Arr...\n", | |
| " 26 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#s...\n", | |
| " 1 ...c/tracker/back.jl:6; scan(::TrackedArray{…,Ar...\n", | |
| " 24 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Ar...\n", | |
| " 20 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#...\n", | |
| " 1 .../tracker/back.jl:0; scan(::TrackedArray{…,A...\n", | |
| " 1 .../tracker/back.jl:6; scan(::TrackedArray{…,A...\n", | |
| " 17 .../tracker/back.jl:8; scan(::TrackedArray{…,A...\n", | |
| " 7 ./abstractarray.jl:1731; foreach(::Flux.Tracker....\n", | |
| " 7 .../tracker/back.jl:8; scan(::TrackedArray{…,...\n", | |
| " 2 ...stractarray.jl:1731; foreach(::Flux.Tracker...\n", | |
| " 1 ...tracker/back.jl:6; scan(::TrackedArray{…...\n", | |
| " 341 ...src/tracker/back.jl:24; back(::TrackedArray{…,Array...\n", | |
| " 337 ...src/tracker/back.jl:15; back(::Flux.Tracker.Call{Base...\n", | |
| " 335 ...rc/tracker/back.jl:24; back(::TrackedArray{…,Arra...\n", | |
| " 326 ...rc/tracker/back.jl:15; back(::Flux.Tracker.Call{Fl...\n", | |
| " 319 ./abstractarray.jl:1732; foreach(::Function, ::Tupl...\n", | |
| " 1 ./iterators.jl:183; next(::Base.Iterators.Zip2...\n", | |
| " 40 ./iterators.jl:185; next(::Base.Iterators.Zip2...\n", | |
| " 1 ./iterators.jl:0; zip(::Tuple{TrackedArray{…...\n", | |
| " 247 ...rc/tracker/lib.jl:117; (::Flux.Tracker.##17#19)(:...\n", | |
| " 247 ...c/tracker/back.jl:32; macro expansion\n", | |
| " 246 .../tracker/back.jl:24; back(::TrackedArray{…,A...\n", | |
| " 242 .../tracker/back.jl:15; back(::Flux.Tracker.Call...\n", | |
| " 238 ./abstractarray.jl:1732; foreach(::Function, ::T...\n", | |
| " 1 ./iterators.jl:0; done(::Base.Iterators.Z...\n", | |
| " 20 ./iterators.jl:185; next(::Base.Iterators.Z...\n", | |
| " 198 .../tracker/lib.jl:117; (::Flux.Tracker.##17#19...\n", | |
| " 198 ...tracker/back.jl:32; macro expansion\n", | |
| " 159 ...racker/back.jl:24; back(::TrackedArray{…...\n", | |
| " 155 ...racker/back.jl:15; back(::Flux.Tracker.C...\n", | |
| " 12 ...racker/lib.jl:71; back\n", | |
| " 12 ...acker/back.jl:32; macro expansion\n", | |
| " 1 ...alg/matmul.jl:0; A_mul_Bt(::Array{Flo...\n", | |
| " 7 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
| " 5 ...alg/matmul.jl:191; A_mul_Bt!\n", | |
| " 1 ...lg/matmul.jl:366; gemm_wrapper!(::Arr...\n", | |
| " 3 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
| " 2 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
| " 1 ...alg/blas.jl:1036; gemm!(::Char, ::Ch...\n", | |
| " 1 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
| " 1 ./broadcast.jl:204; broadcast!\n", | |
| " 1 ./broadcast.jl:211; broadcast_c!\n", | |
| " 1 ./broadcast.jl:139; _broadcast!\n", | |
| " 1 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:73; macro expansion\n", | |
| " 1 ./broadcast.jl:153; macro expansion\n", | |
| " 2 ...acker/back.jl:22; back(::TrackedArray{...\n", | |
| " 5 ...racker/lib.jl:71; back(::Base.#*, ::Arr...\n", | |
| " 5 ...racker/back.jl:32; macro expansion\n", | |
| " 2 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
| " 1 ...alg/matmul.jl:191; A_mul_Bt!\n", | |
| " 1 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
| " 1 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
| " 2 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
| " 2 ./broadcast.jl:204; broadcast!\n", | |
| " 1 ./broadcast.jl:208; broadcast_c!\n", | |
| " 1 ./broadcast.jl:90; check_broadcast_indices\n", | |
| " 1 ./broadcast.jl:86; check_broadcast_in...\n", | |
| " 1 ./broadcast.jl:48; broadcast_indices\n", | |
| " 1 ./broadcast.jl:52; broadcast_indices\n", | |
| " 1 ...actarray.jl:64; indices\n", | |
| " 1 ./broadcast.jl:211; broadcast_c!\n", | |
| " 1 ./broadcast.jl:139; _broadcast!\n", | |
| " 1 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:73; macro expansion\n", | |
| " 1 ./broadcast.jl:153; macro expansion\n", | |
| " 137 ...racker/lib.jl:72; back\n", | |
| " 137 ...acker/back.jl:32; macro expansion\n", | |
| " 4 ...lg/matmul.jl:182; At_mul_B(::Array{Fl...\n", | |
| " 3 ...alg/matmul.jl:184; At_mul_B!\n", | |
| " 3 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
| " 1 ...alg/blas.jl:0; gemm!(::Char, ::Ch...\n", | |
| " 1 ...alg/blas.jl:1022; gemm!(::Char, ::Ch...\n", | |
| " 1 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
| " 133 ...cker/back.jl:24; back(::TrackedArray...\n", | |
| " 132 ...cker/back.jl:15; back(::Flux.Tracker...\n", | |
| " 129 ...actarray.jl:1732; foreach(::Functio...\n", | |
| " 34 ./iterators.jl:185; next(::Base.Iterat...\n", | |
| " 83 ...cker/lib.jl:117; (::Flux.Tracker.##...\n", | |
| " 83 ...ker/back.jl:32; macro expansion\n", | |
| " 1 ...ker/back.jl:19; back(::TrackedArr...\n", | |
| " 1 ...ker/back.jl:21; back(::TrackedArr...\n", | |
| " 1 ./broadcast.jl:204; broadcast!\n", | |
| " 1 ...oadcast.jl:211; broadcast_c!\n", | |
| " 1 ...oadcast.jl:139; _broadcast!\n", | |
| " 1 ...adcast.jl:147; macro expansion\n", | |
| " 1 ...mdloop.jl:73; macro expansion\n", | |
| " 1 ...adcast.jl:151; macro expansion\n", | |
| " 1 ...ker/back.jl:22; back(::TrackedArr...\n", | |
| " 7 ...ker/back.jl:24; back(::TrackedArr...\n", | |
| " 4 ...ker/back.jl:15; back(::Flux.Track...\n", | |
| " 4 ...cker/lib.jl:71; back(::Base.#*, :...\n", | |
| " 4 ...er/back.jl:32; macro expansion\n", | |
| " 4 ...matmul.jl:189; A_mul_Bt(::Arra...\n", | |
| " 2 ...matmul.jl:191; A_mul_Bt!\n", | |
| " 2 ...matmul.jl:367; gemm_wrapper!(...\n", | |
| " 2 .../blas.jl:1027; gemm!(::Char...\n", | |
| " 72 ...cker/lib.jl:106; unbroadcast(::Tra...\n", | |
| " 44 ./array.jl:1819; filter(::Functi...\n", | |
| " 4 ...tarray.jl:882; getindex\n", | |
| " 4 ...nsional.jl:438; _getindex\n", | |
| " 4 ...sional.jl:442; macro expansion\n", | |
| " 4 ...sional.jl:453; _unsafe_getind...\n", | |
| " 3 ...sional.jl:458; macro expansion\n", | |
| " 1 ...sional.jl:460; macro expansion\n", | |
| " 1 ...ional.jl:466; _unsafe_getindex!\n", | |
| " 1 ...ional.jl:472; macro expansion\n", | |
| " 1 ...esian.jl:62; macro expansion\n", | |
| " 1 ...onal.jl:352; next\n", | |
| " 30 ...tarray.jl:1865; map(::Function,...\n", | |
| " 5 ./array.jl:455; _collect(::Unit...\n", | |
| " 3 ./array.jl:184; reshape(::Array...\n", | |
| " 22 ...ducedim.jl:572; sum\n", | |
| " 22 ...ducedim.jl:570; sum\n", | |
| " 22 ...ucedim.jl:241; mapreducedim\n", | |
| " 2 ...ucedim.jl:210; mapreducedim!\n", | |
| " 2 ...ucedim.jl:173; _mapreducedim!...\n", | |
| " 1 ...ucedim.jl:0; check_reducedi...\n", | |
| " 1 ...ucedim.jl:169; check_reducedi...\n", | |
| " 20 ...ucedim.jl:73; reducedim_init...\n", | |
| " 2 ./array.jl:0; fill!(::Array{...\n", | |
| " 4 ...ucedim.jl:33; reduced_indice...\n", | |
| " 1 ./array.jl:0; vect(::Base.On...\n", | |
| " 3 ./array.jl:76; vect(::Base.On...\n", | |
| " 14 ...ucedim.jl:43; reduced_indice...\n", | |
| " 3 ...cker/lib.jl:116; back\n", | |
| " 3 ...acker/lib.jl:116; (::Flux.Tracker.##...\n", | |
| " 3 ./broadcast.jl:434; broadcast\n", | |
| " 2 ./broadcast.jl:311; broadcast_c\n", | |
| " 2 ./broadcast.jl:53; broadcast_indices\n", | |
| " 2 ./tuple.jl:159; map\n", | |
| " 2 ...oadcast.jl:48; broadcast_indices\n", | |
| " 2 ...oadcast.jl:52; broadcast_indices\n", | |
| " 2 ...tarray.jl:64; indices\n", | |
| " 1 ./broadcast.jl:314; broadcast_c\n", | |
| " 1 ./broadcast.jl:266; broadcast_t\n", | |
| " 37 ...tracker/lib.jl:106; unbroadcast(::Tracked...\n", | |
| " 23 ./array.jl:1819; filter(::Function, ::...\n", | |
| " 5 ...tractarray.jl:882; getindex\n", | |
| " 5 ...imensional.jl:438; _getindex\n", | |
| " 5 ...imensional.jl:442; macro expansion\n", | |
| " 5 ...imensional.jl:453; _unsafe_getindex(::I...\n", | |
| " 4 ...mensional.jl:458; macro expansion\n", | |
| " 1 ...mensional.jl:460; macro expansion\n", | |
| " 1 ...mensional.jl:466; _unsafe_getindex!\n", | |
| " 1 ...ensional.jl:472; macro expansion\n", | |
| " 1 ./cartesian.jl:62; macro expansion\n", | |
| " 1 ...ensional.jl:352; next\n", | |
| " 11 ...tractarray.jl:1865; map(::Function, ::Un...\n", | |
| " 2 ./array.jl:455; _collect(::UnitRange{...\n", | |
| " 1 ./generator.jl:32; Base.Generator(::Flux...\n", | |
| " 13 ./reducedim.jl:572; sum\n", | |
| " 13 ./reducedim.jl:570; sum\n", | |
| " 13 ./reducedim.jl:241; mapreducedim\n", | |
| " 13 ./reducedim.jl:73; reducedim_initarray(...\n", | |
| " 5 ./reducedim.jl:33; reduced_indices(::Tu...\n", | |
| " 1 ./array.jl:0; vect(::Base.OneTo{In...\n", | |
| " 4 ./array.jl:76; vect(::Base.OneTo{In...\n", | |
| " 7 ./reducedim.jl:43; reduced_indices(::Tu...\n", | |
| " 3 .../tracker/lib.jl:116; back\n", | |
| " 1 .../tracker/lib.jl:0; (::Flux.Tracker.##16#18{...\n", | |
| " 2 .../tracker/lib.jl:116; (::Flux.Tracker.##16#18{...\n", | |
| " 2 ./broadcast.jl:434; broadcast\n", | |
| " 1 ./broadcast.jl:311; broadcast_c\n", | |
| " 1 ./broadcast.jl:53; broadcast_indices\n", | |
| " 1 ./broadcast.jl:57; broadcast_shape\n", | |
| " 1 ./broadcast.jl:57; broadcast_shape\n", | |
| " 1 ./broadcast.jl:64; _bcs\n", | |
| " 1 ./broadcast.jl:63; _bcs\n", | |
| " 1 ./broadcast.jl:314; broadcast_c\n", | |
| " 1 ./broadcast.jl:266; broadcast_t\n", | |
| " 1 .../tracker/back.jl:26; back(::TrackedArray{…,A...\n", | |
| " 6 ...rc/tracker/lib.jl:116; back\n", | |
| " 5 ...rc/tracker/lib.jl:116; (::Flux.Tracker.##16#18{Flu...\n", | |
| " 5 ./broadcast.jl:434; broadcast\n", | |
| " 5 ./broadcast.jl:314; broadcast_c\n", | |
| " 4 ./broadcast.jl:266; broadcast_t\n", | |
| " 1 ./broadcast.jl:268; broadcast_t\n", | |
| " 1 ./broadcast.jl:139; _broadcast!\n", | |
| " 1 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:73; macro expansion\n", | |
| " 1 ./broadcast.jl:153; macro expansion\n", | |
| " 1 ...src/tracker/lib.jl:55; back\n", | |
| " 1953 ./In[20]:17; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 1879 ./In[16]:2; update!(::Array{Flux.Tracker.T...\n", | |
| " 1378 ./In[12]:12; loss2\n", | |
| " 1120 ./In[12]:10; predict(::Array{Int64,2})\n", | |
| " 21 ./In[12]:5; hidden2\n", | |
| " 18 ...src/tracker/lib.jl:63; *(::TrackedArray{…,Array{F...\n", | |
| " 17 ...tracker/Tracker.jl:36; Type\n", | |
| " 17 ...racker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
| " 2 ./linalg/matmul.jl:146; *\n", | |
| " 14 ./linalg/matmul.jl:483; generic_matmatmul!(::Array...\n", | |
| " 1 ./linalg/matmul.jl:0; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:490; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:375; lapack_size(::Char, ::Arr...\n", | |
| " 1 ./linalg/matmul.jl:508; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:509; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:515; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:389; copy_transpose!(::Array{F...\n", | |
| " 1 ...alg/transpose.jl:142; copy_transpose!(::Array{...\n", | |
| " 1 ./abstractarray.jl:362; checkbounds\n", | |
| " 2 ./linalg/matmul.jl:516; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:379; copy!(::Array{Int64,2}, :...\n", | |
| " 1 ./abstractarray.jl:716; copy!(::Array{Int64,2}, ...\n", | |
| " 1 ./linalg/matmul.jl:522; _generic_matmatmul!(::Arra...\n", | |
| " 6 ./linalg/matmul.jl:523; _generic_matmatmul!(::Arra...\n", | |
| " 9 ./abstractarray.jl:882; getindex\n", | |
| " 9 ./multidimensional.jl:438; _getindex\n", | |
| " 9 ./multidimensional.jl:442; macro expansion\n", | |
| " 9 ...ltidimensional.jl:453; _unsafe_getindex(::IndexLin...\n", | |
| " 1 ...ltidimensional.jl:456; macro expansion\n", | |
| " 1 ...ltidimensional.jl:457; macro expansion\n", | |
| " 1 ...ltidimensional.jl:311; index_shape\n", | |
| " 1 ./abstractarray.jl:64; indices\n", | |
| " 6 ...ltidimensional.jl:458; macro expansion\n", | |
| " 1 ...ltidimensional.jl:459; macro expansion\n", | |
| " 1 ./tuple.jl:284; ==(::Tuple{Int64,Int64}, :...\n", | |
| " 7 ./broadcast.jl:0; broadcast(::Function, ::Tra...\n", | |
| " 1040 ./broadcast.jl:434; broadcast(::Function, ::Tra...\n", | |
| " 7 ./broadcast.jl:0; containertype(::TrackedArra...\n", | |
| " 14 ./broadcast.jl:34; containertype(::TrackedArra...\n", | |
| " 11 ...rc/tracker/lib.jl:0; broadcast_c(::Function, ::T...\n", | |
| " 364 ...rc/tracker/lib.jl:129; broadcast_c(::Function, ::T...\n", | |
| " 1 ...rc/tracker/lib.jl:0; tracked_broadcast(::Functi...\n", | |
| " 28 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functi...\n", | |
| " 28 ./tuple.jl:181; map\n", | |
| " 1 ./array.jl:454; _collect(::Array{Float64,2...\n", | |
| " 1 ./generator.jl:44; next\n", | |
| " 17 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
| " 14 ./array.jl:375; _similar_for(::Array{Floa...\n", | |
| " 7 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
| " 1 ./array.jl:474; collect_to!(::Array{Forwa...\n", | |
| " 10 ./tuple.jl:178; map(::Flux.Tracker.##8#10{...\n", | |
| " 8 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
| " 3 ./array.jl:375; _similar_for(::Array{Floa...\n", | |
| " 1 ./abstractarray.jl:524; similar(::Array{Float64,...\n", | |
| " 2 ./array.jl:474; collect_to!(::Array{Forwa...\n", | |
| " 1 ...c/tracker/lib.jl:94; (::Flux.Tracker.##6#7{Tup...\n", | |
| " 2 ./array.jl:477; collect_to!(::Array{Forwa...\n", | |
| " 2 ...rc/tracker/lib.jl:97; #8\n", | |
| " 1 ...rc/tracker/lib.jl:94; dualify\n", | |
| " 266 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functi...\n", | |
| " 263 ./broadcast.jl:434; broadcast\n", | |
| " 165 ./broadcast.jl:310; broadcast_c\n", | |
| " 5 ./reflection.jl:510; _methods_by_ftype(::Any, ...\n", | |
| " 6 ./reflection.jl:512; _methods_by_ftype(::Any, ...\n", | |
| " 20 ./reflection.jl:521; _methods_by_ftype(::Any, ...\n", | |
| " 98 ./broadcast.jl:314; broadcast_c\n", | |
| " 5 ./broadcast.jl:0; broadcast_t(::Function, :...\n", | |
| " 46 ./broadcast.jl:266; broadcast_t(::Function, :...\n", | |
| " 4 ./boot.jl:317; Array{ForwardDiff.Dual{Vo...\n", | |
| " 24 ./broadcast.jl:268; broadcast_t(::Function, :...\n", | |
| " 16 ./broadcast.jl:139; _broadcast!(::##11#12, :...\n", | |
| " 16 ./broadcast.jl:147; macro expansion\n", | |
| " 16 ./simdloop.jl:73; macro expansion\n", | |
| " 15 ./broadcast.jl:153; macro expansion\n", | |
| " 15 ./<missing>:0; (::##11#12)(::ForwardDi...\n", | |
| " 1 ...ff/src/dual.jl:198; +\n", | |
| " 1 ...rc/partials.jl:117; _mul_partials\n", | |
| " 1 ...c/partials.jl:219; mul_tuples\n", | |
| " 1 ...c/partials.jl:155; macro expansion\n", | |
| " 13 .../activation.jl:1; σ(::ForwardDiff.Dual...\n", | |
| " 2 ...f/src/dual.jl:169; -\n", | |
| " 1 ...f/src/dual.jl:208; /\n", | |
| " 1 ...rc/partials.jl:82; *\n", | |
| " 1 ...c/partials.jl:109; *\n", | |
| " 1 ...c/partials.jl:199; scale_tuple\n", | |
| " 1 .../partials.jl:155; macro expansion\n", | |
| " 10 ...f/src/dual.jl:169; exp\n", | |
| " 1 ./special/exp.jl:0; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:89; exp(::Float64)\n", | |
| " 8 ./special/exp.jl:112; exp(::Float64)\n", | |
| " 1 ./broadcast.jl:154; macro expansion\n", | |
| " 1 ...idimensional.jl:247; setindex!\n", | |
| " 1 ...rc/tracker/lib.jl:88; Flux.Tracker.Broadcasted(:...\n", | |
| " 49 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functi...\n", | |
| " 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(:...\n", | |
| " 8 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted...\n", | |
| " 8 ./array.jl:455; _collect(::Array{ForwardDi...\n", | |
| " 8 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
| " 1 ./abstractarray.jl:524; similar(::Array{ForwardDi...\n", | |
| " 8 ...rc/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
| " 47 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
| " 1 ./tuple.jl:0; map(::Flux.Tracker.##8#10{2...\n", | |
| " 46 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
| " 19 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
| " 13 ./array.jl:375; _similar_for(::Array{Float...\n", | |
| " 10 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
| " 2 ./array.jl:473; collect_to!(::Array{Forwar...\n", | |
| " 1 ./array.jl:474; collect_to!(::Array{Forwar...\n", | |
| " 1 ...rc/tracker/lib.jl:97; #8\n", | |
| " 1 ...rc/tracker/lib.jl:94; dualify\n", | |
| " 296 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
| " 295 ./broadcast.jl:434; broadcast\n", | |
| " 135 ./broadcast.jl:310; broadcast_c\n", | |
| " 1 ./reflection.jl:0; _methods_by_ftype(::Any, :...\n", | |
| " 2 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
| " 1 ./reflection.jl:512; _methods_by_ftype(::Any, :...\n", | |
| " 12 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
| " 160 ./broadcast.jl:314; broadcast_c\n", | |
| " 11 ./broadcast.jl:0; broadcast_t(::Function, :...\n", | |
| " 32 ./broadcast.jl:266; broadcast_t(::Function, :...\n", | |
| " 5 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
| " 107 ./broadcast.jl:268; broadcast_t(::Function, :...\n", | |
| " 1 ./broadcast.jl:0; _broadcast!(::##9#10, ::...\n", | |
| " 101 ./broadcast.jl:139; _broadcast!(::##9#10, ::...\n", | |
| " 101 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:72; macro expansion\n", | |
| " 100 ./simdloop.jl:73; macro expansion\n", | |
| " 4 ./broadcast.jl:151; macro expansion\n", | |
| " 94 ./broadcast.jl:153; macro expansion\n", | |
| " 92 ./<missing>:0; (::##9#10)(::ForwardDiff...\n", | |
| " 2 ./special/exp.jl:119; exp(::Float64)\n", | |
| " 2 ...ff/src/dual.jl:197; +\n", | |
| " 1 ...ff/src/dual.jl:198; +\n", | |
| " 1 ...rc/partials.jl:117; _mul_partials\n", | |
| " 1 ...rc/partials.jl:219; mul_tuples\n", | |
| " 1 ...c/partials.jl:155; macro expansion\n", | |
| " 2 .../activation.jl:0; σ(::ForwardDiff.Dual{...\n", | |
| " 74 .../activation.jl:1; σ(::ForwardDiff.Dual{...\n", | |
| " 1 ...ff/src/dual.jl:207; +\n", | |
| " 3 ...ff/src/dual.jl:169; -\n", | |
| " 3 ...ff/src/dual.jl:207; /\n", | |
| " 5 ...ff/src/dual.jl:208; /\n", | |
| " 5 ...rc/partials.jl:82; *\n", | |
| " 5 ...c/partials.jl:109; *\n", | |
| " 5 ...c/partials.jl:199; scale_tuple\n", | |
| " 5 ...c/partials.jl:155; macro expansion\n", | |
| " 57 ...ff/src/dual.jl:169; exp\n", | |
| " 1 ./special/exp.jl:0; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:69; exp(::Float64)\n", | |
| " 2 ./special/exp.jl:74; exp(::Float64)\n", | |
| " 2 ./special/exp.jl:85; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:94; exp(::Float64)\n", | |
| " 1 ./special/exp.jl:105; exp(::Float64)\n", | |
| " 2 ./special/exp.jl:110; exp(::Float64)\n", | |
| " 19 ./special/exp.jl:111; exp(::Float64)\n", | |
| " 11 ./special/exp.jl:112; exp(::Float64)\n", | |
| " 4 ./special/exp.jl:117; exp(::Float64)\n", | |
| " 3 ./special/exp.jl:118; exp(::Float64)\n", | |
| " 6 ./special/exp.jl:119; exp(::Float64)\n", | |
| " 4 ./special/exp.jl:132; exp(::Float64)\n", | |
| " 2 ...ff/src/dual.jl:170; exp\n", | |
| " 2 ...rc/partials.jl:82; *\n", | |
| " 2 ...c/partials.jl:109; *\n", | |
| " 2 ...c/partials.jl:199; scale_tuple\n", | |
| " 2 ...c/partials.jl:155; macro expansion\n", | |
| " 2 ./broadcast.jl:154; macro expansion\n", | |
| " 2 ...idimensional.jl:247; setindex!\n", | |
| " 50 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
| " 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::...\n", | |
| " 9 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{...\n", | |
| " 9 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
| " 6 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
| " 2 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
| " 1 ./array.jl:473; collect_to!(::Array{Float6...\n", | |
| " 1 ./array.jl:477; collect_to!(::Array{Float6...\n", | |
| " 5 ...rc/tracker/lib.jl:62; *(::TrackedArray{…,Array{...\n", | |
| " 3 .../tracker/Tracker.jl:36; Type\n", | |
| " 1 ...tracker/Tracker.jl:0; (::Flux.Tracker.Call{Base.#*...\n", | |
| " 2 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*...\n", | |
| " 2 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
| " 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
| " 11 ...rc/tracker/lib.jl:63; *(::TrackedArray{…,Array{...\n", | |
| " 10 ...tracker/Tracker.jl:36; Type\n", | |
| " 10 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
| " 1 ./linalg/matmul.jl:146; *\n", | |
| " 9 ./linalg/matmul.jl:483; generic_matmatmul!(::Array{...\n", | |
| " 1 ./linalg/matmul.jl:494; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:508; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:509; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:515; _generic_matmatmul!(::Arra...\n", | |
| " 1 ./linalg/matmul.jl:389; copy_transpose!(::Array{Fl...\n", | |
| " 1 ...alg/transpose.jl:143; copy_transpose!(::Array{F...\n", | |
| " 1 ./abstractarray.jl:362; checkbounds\n", | |
| " 4 ./linalg/matmul.jl:516; _generic_matmatmul!(::Arra...\n", | |
| " 4 ./linalg/matmul.jl:379; copy!(::Array{Int64,2}, ::...\n", | |
| " 1 ./abstractarray.jl:706; copy!(::Array{Int64,2}, :...\n", | |
| " 1 ./range.jl:393; length\n", | |
| " 1 ./checked.jl:221; checked_sub\n", | |
| " 2 ./abstractarray.jl:715; copy!(::Array{Int64,2}, :...\n", | |
| " 1 ./abstractarray.jl:716; copy!(::Array{Int64,2}, :...\n", | |
| " 1 ./linalg/matmul.jl:581; _generic_matmatmul!(::Arra...\n", | |
| " 239 ./broadcast.jl:434; broadcast(::Function, ::Tra...\n", | |
| " 8 ...src/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
| " 14 ...src/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
| " 14 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
| " 12 ./array.jl:455; _collect(::Array{Float64,2}...\n", | |
| " 10 ./array.jl:375; _similar_for(::Array{Float...\n", | |
| " 9 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
| " 2 ./array.jl:473; collect_to!(::Array{Forwar...\n", | |
| " 1 ...rc/tracker/lib.jl:97; #8\n", | |
| " 170 ...src/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
| " 168 ./broadcast.jl:434; broadcast\n", | |
| " 98 ./broadcast.jl:310; broadcast_c\n", | |
| " 3 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
| " 1 ./reflection.jl:512; _methods_by_ftype(::Any, :...\n", | |
| " 11 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
| " 2 ./broadcast.jl:311; broadcast_c\n", | |
| " 2 ./broadcast.jl:53; broadcast_indices\n", | |
| " 2 ./tuple.jl:158; map\n", | |
| " 2 ./broadcast.jl:48; broadcast_indices\n", | |
| " 2 ./broadcast.jl:52; broadcast_indices\n", | |
| " 2 ...alg/rowvector.jl:112; indices\n", | |
| " 2 ./abstractarray.jl:64; indices\n", | |
| " 68 ./broadcast.jl:314; broadcast_c\n", | |
| " 8 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
| " 32 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
| " 1 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
| " 9 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
| " 2 ./broadcast.jl:139; _broadcast!(::##13#14, ::A...\n", | |
| " 2 ./broadcast.jl:147; macro expansion\n", | |
| " 2 ./simdloop.jl:73; macro expansion\n", | |
| " 2 ./broadcast.jl:153; macro expansion\n", | |
| " 2 ./<missing>:0; (::##13#14)(::ForwardDiff...\n", | |
| " 2 ./intfuncs.jl:205; literal_pow\n", | |
| " 1 ...iff/src/dual.jl:362; ^\n", | |
| " 1 ...iff/src/dual.jl:363; ^\n", | |
| " 1 ...src/partials.jl:82; *\n", | |
| " 1 ...rc/partials.jl:109; *\n", | |
| " 1 ...c/partials.jl:199; scale_tuple\n", | |
| " 1 ...c/partials.jl:155; macro expansion\n", | |
| " 39 ...src/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
| " 1 .../tracker/Tracker.jl:0; Flux.Tracker.TrackedArray(::F...\n", | |
| " 2 .../tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::F...\n", | |
| " 7 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
| " 7 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
| " 6 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
| " 1 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
| " 1 ./reduce.jl:276; _mapreduce(::Base.#identity,...\n", | |
| " 497 ...rc/tracker/back.jl:43; back!(::TrackedArray{…,Arr...\n", | |
| " 41 ...src/tracker/back.jl:39; back!\n", | |
| " 41 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array...\n", | |
| " 39 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sca...\n", | |
| " 39 ...rc/tracker/back.jl:8; scan(::TrackedArray{…,Arra...\n", | |
| " 33 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sc...\n", | |
| " 32 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Arr...\n", | |
| " 31 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#...\n", | |
| " 1 ...c/tracker/back.jl:6; scan(::TrackedArray{…,Ar...\n", | |
| " 30 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Ar...\n", | |
| " 22 ./abstractarray.jl:1731; foreach(::Flux.Tracker....\n", | |
| " 22 .../tracker/back.jl:8; scan(::TrackedArray{…,...\n", | |
| " 13 ...stractarray.jl:1731; foreach(::Flux.Tracker...\n", | |
| " 1 ...tracker/back.jl:0; scan(::TrackedArray{…...\n", | |
| " 1 ...tracker/back.jl:6; scan(::TrackedArray{…...\n", | |
| " 10 ...tracker/back.jl:8; scan(::TrackedArray{…...\n", | |
| " 3 ...stractarray.jl:1731; foreach(::Flux.Tracke...\n", | |
| " 1 ...tracker/back.jl:8; scan(::TrackedArray{…...\n", | |
| " 1 ...tracker/back.jl:12; scan(::TrackedArray{…...\n", | |
| " 456 ...src/tracker/back.jl:24; back(::TrackedArray{…,Array...\n", | |
| " 453 ...rc/tracker/back.jl:15; back(::Flux.Tracker.Call{Bas...\n", | |
| " 446 ...rc/tracker/back.jl:24; back(::TrackedArray{…,Arr...\n", | |
| " 435 ...c/tracker/back.jl:15; back(::Flux.Tracker.Call{Fl...\n", | |
| " 410 ./abstractarray.jl:1732; foreach(::Function, ::Tup...\n", | |
| " 26 ./iterators.jl:185; next(::Base.Iterators.Zip...\n", | |
| " 1 ./iterators.jl:0; zip(::Tuple{TrackedArray{...\n", | |
| " 361 ...c/tracker/lib.jl:117; (::Flux.Tracker.##17#19)(...\n", | |
| " 361 .../tracker/back.jl:32; macro expansion\n", | |
| " 1 ...tracker/back.jl:19; back(::TrackedArray{…,...\n", | |
| " 1 ./refpointer.jl:121; setindex!\n", | |
| " 359 ...tracker/back.jl:24; back(::TrackedArray{…,...\n", | |
| " 357 ...tracker/back.jl:15; back(::Flux.Tracker.Call...\n", | |
| " 345 ...stractarray.jl:1732; foreach(::Function, ::...\n", | |
| " 32 ./iterators.jl:185; next(::Base.Iterators....\n", | |
| " 294 ...tracker/lib.jl:117; (::Flux.Tracker.##17#1...\n", | |
| " 294 ...racker/back.jl:32; macro expansion\n", | |
| " 1 ...acker/back.jl:22; back(::TrackedArray{…...\n", | |
| " 245 ...acker/back.jl:24; back(::TrackedArray{…...\n", | |
| " 238 ...acker/back.jl:15; back(::Flux.Tracker.C...\n", | |
| " 13 ...racker/lib.jl:71; back\n", | |
| " 13 ...acker/back.jl:32; macro expansion\n", | |
| " 1 ...alg/matmul.jl:0; A_mul_Bt(::Array{Flo...\n", | |
| " 8 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
| " 8 ...lg/matmul.jl:191; A_mul_Bt!\n", | |
| " 1 ...lg/matmul.jl:366; gemm_wrapper!(::Ar...\n", | |
| " 7 ...lg/matmul.jl:367; gemm_wrapper!(::Ar...\n", | |
| " 6 ...alg/blas.jl:1027; gemm!(::Char, ::C...\n", | |
| " 2 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
| " 2 ./broadcast.jl:204; broadcast!\n", | |
| " 1 ./broadcast.jl:208; broadcast_c!\n", | |
| " 1 ./broadcast.jl:211; broadcast_c!\n", | |
| " 1 ./broadcast.jl:139; _broadcast!\n", | |
| " 1 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:66; macro expansion\n", | |
| " 2 ...acker/back.jl:22; back(::TrackedArray{...\n", | |
| " 17 ...racker/lib.jl:71; back(::Base.#*, ::Ar...\n", | |
| " 17 ...acker/back.jl:32; macro expansion\n", | |
| " 13 ...lg/matmul.jl:483; generic_matmatmul!(...\n", | |
| " 5 ...lg/matmul.jl:508; _generic_matmatmul!...\n", | |
| " 3 ...lg/matmul.jl:509; _generic_matmatmul!...\n", | |
| " 2 ...lg/matmul.jl:515; _generic_matmatmul!...\n", | |
| " 2 ...lg/matmul.jl:389; copy_transpose!(::...\n", | |
| " 2 ...ranspose.jl:148; copy_transpose!(::...\n", | |
| " 3 ...lg/matmul.jl:516; _generic_matmatmul!...\n", | |
| " 3 ...lg/matmul.jl:381; copy!(::Array{Int6...\n", | |
| " 1 ...ranspose.jl:143; copy_transpose!(::...\n", | |
| " 1 ...actarray.jl:362; checkbounds\n", | |
| " 1 ...ranspose.jl:147; copy_transpose!(::...\n", | |
| " 1 ...ranspose.jl:148; copy_transpose!(::...\n", | |
| " 1 ...cker/back.jl:0; back(::TrackedArray...\n", | |
| " 1 ...cker/back.jl:19; back(::TrackedArray...\n", | |
| " 1 ...cker/back.jl:21; back(::TrackedArray...\n", | |
| " 207 ...racker/lib.jl:72; back\n", | |
| " 207 ...acker/back.jl:32; macro expansion\n", | |
| " 4 ...lg/matmul.jl:182; At_mul_B(::Array{F...\n", | |
| " 4 ...lg/matmul.jl:184; At_mul_B!\n", | |
| " 4 ...lg/matmul.jl:367; gemm_wrapper!(::Ar...\n", | |
| " 1 ...alg/blas.jl:0; gemm!(::Char, ::C...\n", | |
| " 3 ...alg/blas.jl:1027; gemm!(::Char, ::C...\n", | |
| " 203 ...cker/back.jl:24; back(::TrackedArra...\n", | |
| " 199 ...cker/back.jl:15; back(::Flux.Tracke...\n", | |
| " 178 ...actarray.jl:1732; foreach(::Functio...\n", | |
| " 1 ./iterators.jl:187; done(::Base.Itera...\n", | |
| " 22 ./iterators.jl:185; next(::Base.Itera...\n", | |
| " 132 ...cker/lib.jl:117; (::Flux.Tracker.#...\n", | |
| " 132 ...ker/back.jl:32; macro expansion\n", | |
| " 1 ...er/back.jl:21; back(::TrackedAr...\n", | |
| " 1 ...oadcast.jl:204; broadcast!\n", | |
| " 1 ...oadcast.jl:211; broadcast_c!\n", | |
| " 1 ...adcast.jl:139; _broadcast!\n", | |
| " 1 ...adcast.jl:147; macro expansion\n", | |
| " 1 ...mdloop.jl:73; macro expansion\n", | |
| " 1 ...dcast.jl:151; macro expansion\n", | |
| " 2 ...er/back.jl:22; back(::TrackedAr...\n", | |
| " 48 ...er/back.jl:24; back(::TrackedAr...\n", | |
| " 45 ...er/back.jl:15; back(::Flux.Trac...\n", | |
| " 43 ...ker/lib.jl:71; back(::Base.#*,...\n", | |
| " 43 ...r/back.jl:32; macro expansion\n", | |
| " 1 ...matmul.jl:189; A_mul_Bt\n", | |
| " 1 ...matmul.jl:473; generic_matmat...\n", | |
| " 40 ...matmul.jl:483; generic_matmat...\n", | |
| " 4 ...atmul.jl:509; _generic_matm...\n", | |
| " 5 ...atmul.jl:515; _generic_matm...\n", | |
| " 5 ...atmul.jl:389; copy_transpos...\n", | |
| " 1 ...spose.jl:147; copy_transpo...\n", | |
| " 4 ...spose.jl:148; copy_transpo...\n", | |
| " 4 ...atmul.jl:516; _generic_matm...\n", | |
| " 3 ...atmul.jl:381; copy!(::Array...\n", | |
| " 1 ...spose.jl:134; copy_transpo...\n", | |
| " 1 ./range.jl:393; length\n", | |
| " 1 ...cked.jl:164; checked_add\n", | |
| " 1 ...spose.jl:146; copy_transpo...\n", | |
| " 1 ...spose.jl:147; copy_transpo...\n", | |
| " 1 ...atmul.jl:519; _generic_matm...\n", | |
| " 4 ...atmul.jl:522; _generic_matm...\n", | |
| " 17 ...atmul.jl:523; _generic_matm...\n", | |
| " 2 ...atmul.jl:525; _generic_matm...\n", | |
| " 1 ...atmul.jl:581; _generic_matm...\n", | |
| " 1 ...r/back.jl:21; back(::Tracked...\n", | |
| " 1 ...adcast.jl:204; broadcast!\n", | |
| " 1 ...dcast.jl:207; broadcast_c!\n", | |
| " 1 ...array.jl:64; indices\n", | |
| " 80 ...ker/lib.jl:106; unbroadcast(::Tr...\n", | |
| " 49 ./array.jl:1819; filter(::Functi...\n", | |
| " 2 ...tarray.jl:882; getindex\n", | |
| " 2 ...sional.jl:438; _getindex\n", | |
| " 2 ...sional.jl:442; macro expansion\n", | |
| " 1 ...sional.jl:453; _unsafe_getind...\n", | |
| " 1 ...ional.jl:458; macro expansion\n", | |
| " 33 ...tarray.jl:1865; map(::Function...\n", | |
| " 4 ./array.jl:455; _collect(::Unit...\n", | |
| " 1 ./array.jl:0; reshape(::Array...\n", | |
| " 1 ./array.jl:184; reshape(::Array...\n", | |
| " 25 ...ucedim.jl:572; sum\n", | |
| " 25 ...ucedim.jl:570; sum\n", | |
| " 25 ...ucedim.jl:241; mapreducedim\n", | |
| " 2 ...ucedim.jl:210; mapreducedim!\n", | |
| " 2 ...ucedim.jl:202; _mapreducedim!...\n", | |
| " 2 ...mdloop.jl:71; macro expansion\n", | |
| " 23 ...ucedim.jl:73; reducedim_init...\n", | |
| " 3 ...ucedim.jl:33; reduced_indice...\n", | |
| " 1 ./array.jl:0; vect(::Base.On...\n", | |
| " 1 ./array.jl:76; vect(::Base.On...\n", | |
| " 16 ...ucedim.jl:43; reduced_indice...\n", | |
| " 20 ...cker/lib.jl:116; back\n", | |
| " 1 ...cker/lib.jl:0; (::Flux.Tracker.#...\n", | |
| " 19 ...cker/lib.jl:116; (::Flux.Tracker.#...\n", | |
| " 19 ./broadcast.jl:434; broadcast\n", | |
| " 2 ...oadcast.jl:311; broadcast_c\n", | |
| " 2 ./broadcast.jl:53; broadcast_indices\n", | |
| " 1 ...oadcast.jl:48; broadcast_indices\n", | |
| " 1 ...oadcast.jl:52; broadcast_indices\n", | |
| " 1 ...tarray.jl:64; indices\n", | |
| " 1 ...oadcast.jl:57; broadcast_shape\n", | |
| " 1 ...oadcast.jl:57; broadcast_shape\n", | |
| " 1 ...adcast.jl:63; _bcs\n", | |
| " 1 ...adcast.jl:0; _bcs1(::Base.On...\n", | |
| " 17 ...oadcast.jl:314; broadcast_c\n", | |
| " 8 ...oadcast.jl:266; broadcast_t\n", | |
| " 9 ...oadcast.jl:268; broadcast_t\n", | |
| " 9 ...oadcast.jl:139; _broadcast!\n", | |
| " 9 ...adcast.jl:147; macro expansion\n", | |
| " 9 ...mdloop.jl:73; macro expansion\n", | |
| " 8 ...adcast.jl:151; macro expansion\n", | |
| " 1 ...adcast.jl:154; macro expansion\n", | |
| " 1 ...ional.jl:247; setindex!\n", | |
| " 47 ...racker/lib.jl:106; unbroadcast(::Tracked...\n", | |
| " 23 ./array.jl:1819; filter(::Function, :...\n", | |
| " 2 ...tractarray.jl:882; getindex\n", | |
| " 2 ...imensional.jl:438; _getindex\n", | |
| " 1 ...imensional.jl:441; macro expansion\n", | |
| " 1 ...ractarray.jl:362; checkbounds\n", | |
| " 1 ...imensional.jl:442; macro expansion\n", | |
| " 1 ...mensional.jl:453; _unsafe_getindex(::...\n", | |
| " 1 ...mensional.jl:458; macro expansion\n", | |
| " 18 ...tractarray.jl:1865; map(::Function, ::Un...\n", | |
| " 2 ./array.jl:455; _collect(::UnitRange...\n", | |
| " 1 ./array.jl:184; reshape(::Array{Floa...\n", | |
| " 22 ./reducedim.jl:572; sum\n", | |
| " 22 ./reducedim.jl:570; sum\n", | |
| " 22 ./reducedim.jl:241; mapreducedim\n", | |
| " 2 ./reducedim.jl:210; mapreducedim!\n", | |
| " 2 ./reducedim.jl:194; _mapreducedim!(::Ba...\n", | |
| " 2 ./simdloop.jl:73; macro expansion\n", | |
| " 2 ./reducedim.jl:195; macro expansion\n", | |
| " 20 ./reducedim.jl:73; reducedim_initarray...\n", | |
| " 1 ./reducedim.jl:0; reduced_indices(::T...\n", | |
| " 1 ./reducedim.jl:33; reduced_indices(::T...\n", | |
| " 1 ./array.jl:76; vect(::Base.OneTo{I...\n", | |
| " 17 ./reducedim.jl:43; reduced_indices(::T...\n", | |
| " 12 ...tracker/lib.jl:116; back\n", | |
| " 11 .../tracker/lib.jl:116; (::Flux.Tracker.##16#18...\n", | |
| " 11 ./broadcast.jl:434; broadcast\n", | |
| " 1 ./broadcast.jl:311; broadcast_c\n", | |
| " 1 ./broadcast.jl:53; broadcast_indices\n", | |
| " 1 ./tuple.jl:159; map\n", | |
| " 1 ./broadcast.jl:48; broadcast_indices\n", | |
| " 1 ./broadcast.jl:52; broadcast_indices\n", | |
| " 1 ...tractarray.jl:64; indices\n", | |
| " 10 ./broadcast.jl:314; broadcast_c\n", | |
| " 8 ./broadcast.jl:266; broadcast_t\n", | |
| " 2 ./broadcast.jl:268; broadcast_t\n", | |
| " 2 ./broadcast.jl:139; _broadcast!\n", | |
| " 2 ./broadcast.jl:147; macro expansion\n", | |
| " 2 ./simdloop.jl:73; macro expansion\n", | |
| " 2 ./broadcast.jl:153; macro expansion\n", | |
| " 1 .../tracker/lib.jl:106; unbroadcast(::TrackedArr...\n", | |
| " 21 ...c/tracker/lib.jl:116; back\n", | |
| " 1 ...rc/tracker/lib.jl:0; (::Flux.Tracker.##16#18{Fl...\n", | |
| " 19 ...rc/tracker/lib.jl:116; (::Flux.Tracker.##16#18{Fl...\n", | |
| " 19 ./broadcast.jl:434; broadcast\n", | |
| " 19 ./broadcast.jl:314; broadcast_c\n", | |
| " 17 ./broadcast.jl:266; broadcast_t\n", | |
| " 2 ./broadcast.jl:268; broadcast_t\n", | |
| " 2 ./broadcast.jl:139; _broadcast!\n", | |
| " 2 ./broadcast.jl:147; macro expansion\n", | |
| " 2 ./simdloop.jl:73; macro expansion\n", | |
| " 2 ./broadcast.jl:153; macro expansion\n", | |
| " 7 ...src/tracker/lib.jl:55; back\n", | |
| " 1 ...src/tracker/lib.jl:52; sum(::TrackedArray{…,Array...\n", | |
| " 1 ...x/src/tracker/lib.jl:4; toarray\n", | |
| " 2 ./In[16]:3; update!(::Array{Flux.Tracker.T...\n", | |
| " 6 ./In[16]:4; update!(::Array{Flux.Tracker.T...\n", | |
| " 57 ./In[16]:5; update!(::Array{Flux.Tracker.T...\n", | |
| " 29 ./broadcast.jl:204; broadcast!(::Function, ::Arra...\n", | |
| " 1 ./broadcast.jl:208; broadcast_c!\n", | |
| " 1 ./broadcast.jl:89; check_broadcast_indices\n", | |
| " 1 ./broadcast.jl:84; check_broadcast_shape(::Tuple...\n", | |
| " 1 ./broadcast.jl:0; check_broadcast_shape(::Tuple...\n", | |
| " 26 ./broadcast.jl:211; broadcast_c!\n", | |
| " 6 ./broadcast.jl:139; _broadcast!(::##15#16, ::Arra...\n", | |
| " 2 ./broadcast.jl:144; macro expansion\n", | |
| " 4 ./broadcast.jl:147; macro expansion\n", | |
| " 1 ./simdloop.jl:66; macro expansion\n", | |
| " 2 ./simdloop.jl:72; macro expansion\n", | |
| " 1 ./simdloop.jl:73; macro expansion\n", | |
| " 1 ./broadcast.jl:154; macro expansion\n", | |
| " 1 ...ltidimensional.jl:247; setindex!\n", | |
| " 8 ./In[16]:6; update!(::Array{Flux.Tracker.T...\n", | |
| " 2 ./broadcast.jl:22; broadcast!(::Base.#identity, :...\n", | |
| " 1 ./array.jl:227; fill!(::Array{Float64,2}, ::I...\n", | |
| " 1 ./array.jl:228; fill!(::Array{Float64,1}, ::I...\n", | |
| " 5 ./In[20]:19; train!(::Array{Flux.Tracker.Tr...\n", | |
| " 4 ....6/Flux/src/utils.jl:37; (::Flux.##throttled#4#9{Bool,Bo...\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
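| "# reset every parameter in ps, then profile a full training run\n", | |
| "reset!.(ps)\n", | |
| "@profile train!(ps, xys, 0.3, 2000, cb)\n", | |
| "# print the collected samples as a call tree\n", | |
| "Profile.print()" | |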
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Julia 0.6.0", | |
| "language": "julia", | |
| "name": "julia-0.6" | |
| }, | |
| "language_info": { | |
| "file_extension": ".jl", | |
| "mimetype": "application/julia", | |
| "name": "julia", | |
| "version": "0.6.0" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |