Reto de métodos numéricos: Día 23
Durante octubre (2017) estaré escribiendo un programa por día para algunos métodos numéricos famosos en Python y Julia. Esto está pensado como un ejercicio, no esperen que el código sea lo suficientemente bueno para usarse en la "vida real". Además, también debo mencionar que casi que no tengo experiencia con Julia, así que probablemente no escriba un Julia idiomático y se parezca más a Python.
Método de Ritz
Hoy tenemos el método de Ritz para resolver la ecuación:
con
El método consiste en formar un funcional que es equivalente a la ecuación diferencial, proponer una aproximación como una combinación lineal de un conjunto de funciones base y encontrar el mejor conjunto de coeficientes para esta combinación. Este mejor solución se encuentra minimizando el funcional.
El funcional para esta ecuación diferencial es
En este caso, estamos usando la aproximación
en donde escogimos el factor \(x (1 - x)\) para forzar que las funciones satisfagan las condiciones de frontera. El funcional aproximado es
en donde, en general, necesitamos realizar una integración numérica para el segundo término.
Minimizando el funcional
obtenmos el siguiente sistema de ecuaciones
con
y
Probaremos la implementación con la función \(f(x) = x^3\), que lleva a la solución
A continuación se presenta el código.
Python
from __future__ import division, print_function import numpy as np from scipy.integrate import quad from scipy.linalg import solve import matplotlib.pyplot as plt def ritz(N, source): stiff_mat = np.zeros((N, N)) rhs = np.zeros((N)) for row in range(N): for col in range(N): numer = (2 + 2*row + 2*col + 2*row*col) denom = (row + col + 1) * (row + col + 2) * (row + col + 3) stiff_mat[row, col] = numer/denom fun = lambda x: x**(row + 1)*(1 - x)*source(x) rhs[row], _ = quad(fun, 0, 1) return stiff_mat, rhs N = 2 source = lambda x: x**3 mat, rhs = ritz(N, source) c = solve(mat, -rhs) x = np.linspace(0, 1, 100) y = np.zeros_like(x) for cont in range(N): y += c[cont]*x**(cont + 1)*(1 - x) #%% Plotting plt.figure(figsize=(4, 3)) plt.plot(x, y) plt.plot(x, x*(x**4 - 1)/20, linestyle="dashed") plt.xlabel(r"$x$") plt.ylabel(r"$y$") plt.legend(["Ritz solution", "Exact solution"]) plt.tight_layout() plt.show()
Julia
using PyPlot function ritz(N, source) stiff_mat = zeros(N, N) rhs = zeros(N) for row in 0:N-1 for col in 0:N-1 numer = (2 + 2*row + 2*col + 2*row*col) denom = (row + col + 1) * (row + col + 2) * (row + col + 3) stiff_mat[row + 1, col + 1] = numer/denom end fun(x) = x^(row + 1)*(1 - x)*source(x) rhs[row + 1], _ = quadgk(fun, 0, 1) end return stiff_mat, rhs end N = 2 source(x) = x^3 mat, rhs = ritz(N, source) c = -mat\rhs x = linspace(0, 1, 100) y = zeros(x) for cont in 0:N - 1 y += c[cont + 1]*x.^(cont + 1).*(1 - x) end #%% Plotting figure(figsize=(4, 3)) plot(x, y) plot(x, x.*(x.^4 - 1)/20, linestyle="dashed") xlabel(L"$x$") ylabel(L"$y$") legend(["Ritz solution", "Exact solution"]) tight_layout() show()
Ambos tiene (casi) el mismo resultado y se muestra a continuación
Y si consideramos 3 términos en la expansion, obtenemos
Comparación Python/Julia
Respecto al número de líneas tenemos: 38 en Python y 38 en Julia. La comparación
en tiempo de ejecución se realizó con el comando mágico de IPython %timeit
y con @benchmark
en Julia.
Para Python:
%%timeit mat, rhs = ritz(5, source) c = solve(mat, -rhs)
con resultado
1000 loops, best of 3: 228 µs per loop
Para Julia:
function bench() mat, rhs = ritz(N, source) c = -mat\rhs end @benchmark bench()
con resultado
BenchmarkTools.Trial: memory estimate: 6.56 KiB allocs estimate: 340 -------------- minimum time: 13.527 μs (0.00% GC) median time: 15.927 μs (0.00% GC) mean time: 17.133 μs (4.50% GC) maximum time: 2.807 ms (97.36% GC) -------------- samples: 10000 evals/sample: 1
En este caso, podemos decir que el código de Python es alrededor de 14 veces más lento que el de Julia.
Comentarios
Comments powered by Disqus