Nahrawy commited on
Commit
478112e
·
1 Parent(s): c605921

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ from sklearn import linear_model
6
+
7
+ def plot(seed, num_points):
8
+ # Error handling of non-numeric seeds
9
+ if seed and not seed.isnumeric():
10
+ raise gr.Error("Invalid seed")
11
+
12
+ # Setting the seed
13
+ if seed:
14
+ seed = int(seed)
15
+ np.random.seed(seed)
16
+ num_points = int(num_points)
17
+
18
+ #Ensuring the number of points is even
19
+ if num_points%2 != 0:
20
+ num_points +=1
21
+ half_num_points = int(num_points/2)
22
+
23
+ X = np.r_[np.random.randn(half_num_points, 2) + [1, 1], np.random.randn(half_num_points, 2)]
24
+ y = [1] * half_num_points + [-1] * half_num_points
25
+ sample_weight = 100 * np.abs(np.random.randn(num_points))
26
+ # and assign a bigger weight to the last 10 samples
27
+ sample_weight[:half_num_points] *= 10
28
+
29
+ # plot the weighted data points
30
+ xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))
31
+ fig, ax = plt.subplots()
32
+ ax.scatter(
33
+ X[:, 0],
34
+ X[:, 1],
35
+ c=y,
36
+ s=sample_weight,
37
+ alpha=0.9,
38
+ cmap=plt.cm.bone,
39
+ edgecolor="black",
40
+ )
41
+
42
+ # fit the unweighted model
43
+ clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)
44
+ clf.fit(X, y)
45
+ Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
46
+ Z = Z.reshape(xx.shape)
47
+ no_weights = ax.contour(xx, yy, Z, levels=[0], linestyles=["solid"])
48
+
49
+ # fit the weighted model
50
+ clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)
51
+ clf.fit(X, y, sample_weight=sample_weight)
52
+ Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
53
+ Z = Z.reshape(xx.shape)
54
+ samples_weights = ax.contour(xx, yy, Z, levels=[0], linestyles=["dashed"])
55
+
56
+ no_weights_handles, _ = no_weights.legend_elements()
57
+ weights_handles, _ = samples_weights.legend_elements()
58
+ ax.legend(
59
+ [no_weights_handles[0], weights_handles[0]],
60
+ ["no weights", "with weights"],
61
+ loc="lower left",
62
+ )
63
+
64
+ ax.set(xticks=(), yticks=())
65
+ return fig
66
+
67
+ info = ''' # SGD: Weighted samples\n
68
+ This is a demonstration of a modified version of [SGD](https://scikit-learn.org/stable/modules/sgd.html#id5) that takes into account the weights of the samples. Where the size of points is proportional to its weight.\n
69
+ Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_weighted_samples.html).
70
+ '''
71
+
72
+ with gr.Blocks() as demo:
73
+ gr.Markdown(info)
74
+ with gr.Row():
75
+ with gr.Column():
76
+ seed = gr.Textbox(label="Seed", info="Leave empty to generate new random points each run ",value=None)
77
+ num_points = gr.Slider(label="Number of Points", value="20", minimum=5, maximum=100, step=2)
78
+ btn = gr.Button("Run")
79
+ out = gr.Plot()
80
+ btn.click(fn=plot, inputs=[seed,num_points] , outputs=out)
81
+ demo.launch()