{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Fitting B time series (animation)\n\nThis example demonstrates how to fit an SECS system\nto a time series of generic B observations, use that\nfit to make predictions on a separate grid, and\ncompare the predictions with the observations.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.animation import FuncAnimation\n",
        "import numpy as np\n",
        "\n",
        "from pysecs import SECS\n",
        "\n",
        "\n",
        "R_earth = 6371e3\n",
        "\n",
        "\n",
        "def stack_lat_lon_r(lat, lon, r):\n",
        "    \"\"\"Flatten matching lat/lon/r grids into an (npoints, 3) array.\"\"\"\n",
        "    return np.column_stack((lat.ravel(), lon.ravel(), r.ravel()))\n",
        "\n",
        "\n",
        "# specify the SECS grid\n",
        "lat, lon, r = np.meshgrid(np.linspace(-20, 20, 30),\n",
        "                          np.linspace(-20, 20, 30),\n",
        "                          R_earth + 110000, indexing='ij')\n",
        "secs_lat_lon_r = stack_lat_lon_r(lat, lon, r)\n",
        "\n",
        "# Set up the class\n",
        "secs = SECS(sec_df_loc=secs_lat_lon_r)\n",
        "\n",
        "# Make a grid of input observations spanning\n",
        "# (-10, 10) in latitude and longitude\n",
        "lat, lon, r = np.meshgrid(np.linspace(-10, 10, 11),\n",
        "                          np.linspace(-10, 10, 11),\n",
        "                          R_earth, indexing='ij')\n",
        "obs_lat = lat[..., 0]\n",
        "obs_lon = lon[..., 0]\n",
        "obs_lat_lon_r = stack_lat_lon_r(lat, lon, r)\n",
        "nobs = len(obs_lat_lon_r)\n",
        "\n",
        "# Create the synthetic magnetic field data as a function\n",
        "# of time\n",
        "ts = np.linspace(0, 2*np.pi)\n",
        "bx = 5*np.cos(ts)\n",
        "by = 5*np.sin(ts)\n",
        "bz = ts\n",
        "# ntimes x 3\n",
        "B_obs = np.column_stack([bx, by, bz])\n",
        "ntimes = len(B_obs)\n",
        "\n",
        "# Repeat that for each observatory\n",
        "# ntimes x nobs x 3\n",
        "B_obs = np.repeat(B_obs[:, np.newaxis, :], nobs, axis=1)\n",
        "# Make it more interesting and add a sin wave in spatial\n",
        "# coordinates too\n",
        "B_obs[:, :, 0] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 0]))\n",
        "B_obs[:, :, 1] *= 2*np.sin(np.deg2rad(obs_lat_lon_r[:, 1]))\n",
        "\n",
        "B_std = np.ones(B_obs.shape)\n",
        "# Ignore the Z component\n",
        "B_std[..., 2] = np.inf\n",
        "# Can modify the standard error as a function of time to\n",
        "# see how that changes the fits too\n",
        "# B_std[:, 0, 1] = 1 + ts\n",
        "\n",
        "# Fit the data, requires observation locations and data\n",
        "secs.fit(obs_loc=obs_lat_lon_r, obs_B=B_obs, obs_std=B_std)\n",
        "\n",
        "# Create prediction points\n",
        "# Extend it a little beyond the observation points (-11, 11)\n",
        "lat, lon, r = np.meshgrid(np.linspace(-11, 11, 11),\n",
        "                          np.linspace(-11, 11, 11),\n",
        "                          R_earth, indexing='ij')\n",
        "pred_lat = lat[..., 0]\n",
        "pred_lon = lon[..., 0]\n",
        "pred_lat_lon_r = stack_lat_lon_r(lat, lon, r)\n",
        "\n",
        "# Call the prediction function\n",
        "B_pred = secs.predict(pred_lat_lon_r)\n",
        "\n",
        "# Now set up the plots\n",
        "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2,\n",
        "                                             sharex=True, sharey=True)\n",
        "vmin, vmax = -5, 5\n",
        "cmap = plt.get_cmap('RdBu_r')\n",
        "# Time index drawn before the animation starts\n",
        "t0 = 10\n",
        "\n",
        "mesh1 = ax1.pcolormesh(obs_lon, obs_lat, B_obs[t0, :, 0].reshape(obs_lon.shape),\n",
        "                       vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')\n",
        "mesh2 = ax2.pcolormesh(obs_lon, obs_lat, B_obs[t0, :, 1].reshape(obs_lon.shape),\n",
        "                       vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')\n",
        "ax1.set_ylabel(\"Observations\")\n",
        "ax1.set_title(\"B$_{X}$\")\n",
        "ax2.set_title(\"B$_{Y}$\")\n",
        "\n",
        "# Arrows are drawn with U = B_Y along longitude (x axis) and\n",
        "# V = B_X along latitude (y axis)\n",
        "qscale = 1\n",
        "q1 = ax1.quiver(obs_lat_lon_r[:, 1], obs_lat_lon_r[:, 0],\n",
        "                B_obs[t0, :, 1], B_obs[t0, :, 0],\n",
        "                angles='xy', scale_units='xy', scale=qscale)\n",
        "q2 = ax2.quiver(obs_lat_lon_r[:, 1], obs_lat_lon_r[:, 0],\n",
        "                B_obs[t0, :, 1], B_obs[t0, :, 0],\n",
        "                angles='xy', scale_units='xy', scale=qscale)\n",
        "\n",
        "mesh3 = ax3.pcolormesh(pred_lon, pred_lat, B_pred[t0, :, 0].reshape(pred_lon.shape),\n",
        "                       vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')\n",
        "mesh4 = ax4.pcolormesh(pred_lon, pred_lat, B_pred[t0, :, 1].reshape(pred_lon.shape),\n",
        "                       vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')\n",
        "\n",
        "q3 = ax3.quiver(pred_lat_lon_r[:, 1], pred_lat_lon_r[:, 0],\n",
        "                B_pred[t0, :, 1], B_pred[t0, :, 0],\n",
        "                angles='xy', scale_units='xy', scale=qscale)\n",
        "q4 = ax4.quiver(pred_lat_lon_r[:, 1], pred_lat_lon_r[:, 0],\n",
        "                B_pred[t0, :, 1], B_pred[t0, :, 0],\n",
        "                angles='xy', scale_units='xy', scale=qscale)\n",
        "ax3.set_ylabel(\"Predictions\")\n",
        "ax3.set_title(\"B$_{X}$\")\n",
        "ax4.set_title(\"B$_{Y}$\")\n",
        "\n",
        "\n",
        "def update_axes(t):\n",
        "    \"\"\"Redraw every panel (meshes and arrows) with the data at time index t.\"\"\"\n",
        "    # Update the mesh colors\n",
        "    mesh1.set_array(B_obs[t, :, 0].reshape(obs_lon.shape))\n",
        "    mesh2.set_array(B_obs[t, :, 1].reshape(obs_lon.shape))\n",
        "    mesh3.set_array(B_pred[t, :, 0].reshape(pred_lon.shape))\n",
        "    mesh4.set_array(B_pred[t, :, 1].reshape(pred_lon.shape))\n",
        "\n",
        "    # Update the quiver arrows\n",
        "    q1.set_UVC(B_obs[t, :, 1], B_obs[t, :, 0])\n",
        "    q2.set_UVC(B_obs[t, :, 1], B_obs[t, :, 0])\n",
        "    q3.set_UVC(B_pred[t, :, 1], B_pred[t, :, 0])\n",
        "    q4.set_UVC(B_pred[t, :, 1], B_pred[t, :, 0])\n",
        "\n",
        "\n",
        "ani = FuncAnimation(fig, update_axes, frames=range(ntimes),\n",
        "                    interval=50)\n",
        "\n",
        "# The inline backend only renders static images, so plt.show() would\n",
        "# display a single frozen frame. Embed the animation as JavaScript and\n",
        "# show it as the cell's result instead, closing the static figure so it\n",
        "# is not drawn a second time.\n",
        "plt.rcParams['animation.html'] = 'jshtml'\n",
        "plt.close(fig)\n",
        "ani"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}