初入数值分析,如何写好代码

Author Avatar
琉璃 8月 23, 2020

Sum

求和符号频繁的出现在各种公式里面。比如 Composite Simpson Rule:

$\int^a_b f(x)dx = \frac{h}{3}(f(a)+f(b)+4\sum{n/2}_{i=1}f(a+(2i-1)h)+2\sum{(n-2)/2}_{i=1}f(a+2ih))$

我之前看到求和符号的第一反应是这儿又得用 for loop 了。例如 $\sum^{n/2}_{i=1}f(a+(2i-1)h)$ 可能会用以下的代码来计算。

sum=0
for i in range(1,n//2+1):
  sum = sum+f(a+(2*i-1)*h)
sum

这是一个有点冗长,不清楚的写法。

Map

我们可以用 np.summap 函数来实现。

import numpy as np
np.sum(np.array(list(map(lambda i:f(a+(2*i-1)*h),np.arange(1,n//2+1)))))

瞬间把之前的几行代码压缩成了一行。这样写的坏处是括号比较多,在没有括号高亮的情况下容易出现漏括号或者多括号的情况。

Vectorize

这个方法全靠 Numpy。

import numpy as np
i = np.arange(1,n//2+1)
sum1 = lambda i: f(a+(2*i-1)*h)
vfunc = np.vectorize(sum1)
np.sum(vfunc(i))

这个思路感觉和 map 的思路类似,但是优势是创建了一个可以复用的 Vectorize Function,可以接受数组的输入。

Another For Loop

这个其实和 for loop 没什么区别,知识短了一点。

import numpy as np
np.sum([f(a+(2*i-1)*h) for i in np.arange(1,n//2+1) ])

这个方法会比较灵活。在面对多个参数的时候会比较好用。比如有一个函数 f(a,b,c,d), 只有 c 这个参数需要变化,a,b,d 都是不要变化的。我们可以写这样写:

import numpy as np
[f(a,b,c,d) for c in np.arange(n)]

矩阵点乘

这里我们用一个简单一点的例子。我们需要计算$\sum_{i=0}^{n}a_ib_i$。这个的本质其实是 a 和 b 两个矩阵的点乘。

a=[...]
b=[...]
np.dot(a,b)

快速生成一个矩阵

这个很简单。例如生成一个 5*5 矩阵。

import numpy as np
np.zeros((5, 5))

Table

Tabulate

print(tabulate([["Name","Age"],["Alice",24],["Bob",19]],headers="firstrow"))
# Name      Age
# ------  -----
# Alice      24
# Bob        19
print(tabulate({"Name": ["Alice", "Bob"],
"Age": [24, 19]}, headers="keys"))
#   Age  Name
# -----  ------
#    24  Alice
#    19  Bob

Plotly

import plotly.graph_objects as go

fig = go.Figure(data=[go.Table(header=dict(values=['A Scores', 'B Scores']),
                 cells=dict(values=[[100, 90, 80, 90], [95, 85, 75, 95]]))
                     ])
fig.show()

Plot

我每次都记不住怎么画图。

matplotlib

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(9,6))
ax.loglog(hs,composite_trapezoid_rule_error_func1,'bo-',label='CTR ERROR',lw=2)
ax.loglog(hs,np.power(hs,2),'ro-',label='err(h) = h^2',lw=2)
ax.set_title("Error in CTR approximation",fontsize=22)
ax.legend(fontsize=15)
ax.set_xlabel('h',fontsize=22)
ax.set_ylabel('err(h)',fontsize=22)
ax.xaxis.set_tick_params(labelsize=15)
ax.yaxis.set_tick_params(labelsize=15)

伪代码翻译

数组的索引从 0 开始,但是很多伪代码是从 1 开始。为了方便翻译,我们可以在所有数组的开头插入一个 0,这样数组的有效数据便从索引 1 开始。这样伪代码和实际代码之间的索引便不会错了。以下为例子。

import numpy as np
def sor(a,b,XO,omega,TOL,N):
    n=len(a)
    x=np.zeros(n+1)
    a = np.insert(a,0,0,axis=1)
    a = np.insert(a,0,0,axis=0)
    b = np.insert(b,0,0,axis=0)    
    XO = [0.0]+XO
    # step 1
    k=1
    # step 2
    while k<=N:
        # step 3
        for i in np.arange(1,n+1):
            sum1=sum([a[i][j]*x[j] for j in np.arange(1,i)])
            sum2=sum([a[i][j]*XO[j] for j in np.arange(i+1,n+1)])
            x[i]=(1-omega)*XO[i]+1/a[i][i]*(omega*(np.negative(sum1)-sum2+b[i]))
        # step 4
        if max(np.abs(b-np.dot(a,x)))<TOL:
            print(f'Number of Iterations: {k}')
            return x[1:]
        # step 5
        k = k+1
        # step 6
        for i in np.arange(1,n+1):
            XO[i]=x[i]
    print('Maximum number of iterations exceed')
    return XO

assert np.allclose(sor([[4,3,0],[3,4,-1],[0,-1,4]],[24,30,-24],[1.0,1.0,1.0],1.25,1e-05,1000),[3,4,-5], rtol=1e-05, atol=1e-08, equal_nan=False)

This blog is under a CC BY-NC-SA 3.0 Unported License
本文链接:https://www.inevitable.tech/posts/e6738e32/