-
Notifications
You must be signed in to change notification settings - Fork 84
Expand file tree
/
Copy pathConvexity_and_local_extrema.py
More file actions
71 lines (54 loc) · 2.32 KB
/
Convexity_and_local_extrema.py
File metadata and controls
71 lines (54 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 13 22:56:40 2021
@author: Sarat Moka
Python code for generating 3d-plots of
a convex function with a (unique) global minimum
and a non-convex function with several local extrema.
"""
import numpy as np
import matplotlib.pyplot as plt
textsize = 20  # font size of the axis labels in both plots
# Colormaps that were tried for the surfaces, e.g. plt.get_cmap('gist_earth')
# or plt.get_cmap('magma'); pass cmap=... to plot_surface to use one.
# =============================================================================
# Convex and non-convex functions
# =============================================================================


def f_non_convex(t):
    """Evaluate a non-convex surface with several local extrema.

    The expression is the same as MATLAB's ``peaks`` function.

    Parameters
    ----------
    t : sequence
        Pair ``(theta_1, theta_2)``. Each entry may be a scalar or an array
        (e.g. the two outputs of ``np.meshgrid``); arrays are evaluated
        element-wise.

    Returns
    -------
    float or numpy.ndarray
        Function value(s) with the broadcast shape of ``t[0]`` and ``t[1]``.
    """
    x, y = t[0], t[1]
    return (3 * (1 - x)**2 * np.exp(-x**2 - (y + 1)**2)
            - 10 * (x / 5 - x**3 - y**5) * np.exp(-x**2 - y**2)
            - (1 / 3) * np.exp(-(x + 1)**2 - y**2))


def f_convex(t):
    """Evaluate the convex paraboloid ``theta_1**2 + theta_2**2``.

    Its unique global minimum is 0, attained at the origin.

    Parameters
    ----------
    t : sequence
        Pair ``(theta_1, theta_2)`` of scalars or arrays (evaluated
        element-wise).

    Returns
    -------
    float or numpy.ndarray
        Function value(s) with the broadcast shape of ``t[0]`` and ``t[1]``.
    """
    return t[0]**2 + t[1]**2
#%%
# =============================================================================
# This cell plots a convex function
# =============================================================================
# Evaluation grid on [-2.5, 2.5] x [-2.5, 2.5] with spacing 0.01. np.linspace
# includes the end point, so the grid is symmetric about the origin; the
# previous np.arange(-2.5, 2.5, 0.01) stopped at 2.49 and, with a float step,
# its length is subject to round-off.
t1_range = np.linspace(-2.5, 2.5, 501)
t2_range = np.linspace(-2.5, 2.5, 501)
X, Y = np.meshgrid(t1_range, t2_range)  # build the grid once, reuse it for Z
Z = f_convex((X, Y))
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# rstride=cstride=1 draws every grid cell, which is slow to render on this
# 501x501 grid; rcount=50, ccount=50 is a faster, coarser alternative.
# NOTE: vmin/vmax only change the colours when a colormap (cmap=...) is also
# passed; without one, matplotlib draws a single shaded colour.
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=5, antialiased=False,
                edgecolor='none', vmin=np.min(Z), vmax=2 * np.max(Z))
ax.view_init(22, -56)  # elevation 22 deg, azimuth -56 deg
ax.set_xlabel(r'$\theta_1$', fontsize=textsize)
ax.set_ylabel(r'$\theta_2$', fontsize=textsize)
plt.show()
#%%
# =============================================================================
# This cell plots a non-convex function
# =============================================================================
# Same symmetric grid as the convex plot: [-2.5, 2.5]^2 with spacing 0.01
# (np.linspace includes the end point, unlike a float-step np.arange).
t1_range = np.linspace(-2.5, 2.5, 501)
t2_range = np.linspace(-2.5, 2.5, 501)
X, Y = np.meshgrid(t1_range, t2_range)  # build the grid once, reuse it for Z
Z = f_non_convex((X, Y))
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# NOTE: vmin/vmax only change the colours when a colormap (cmap=...) is also
# passed; without one, matplotlib draws a single shaded colour.
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=5, antialiased=False,
                edgecolor='none', vmin=np.min(Z), vmax=2 * np.max(Z), alpha=1)
ax.view_init(22, -56)  # same viewpoint as the convex plot
ax.set_xlabel(r'$\theta_1$', fontsize=textsize)
ax.set_ylabel(r'$\theta_2$', fontsize=textsize)
plt.show()