fabiochiusano commited on
Commit
0ba9aa2
1 Parent(s): b7aa506

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ myvenv
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ import utils
5
+ from kb import KB
6
+
7
+ texts = {
8
+ "Napoleon": "Napoleon Bonaparte (born Napoleone di Buonaparte; 15 August 1769 – 5 May 1821), and later known by his regnal name Napoleon I, was a French military and political leader who rose to prominence during the French Revolution and led several successful campaigns during the Revolutionary Wars. He was the de facto leader of the French Republic as First Consul from 1799 to 1804. As Napoleon I, he was Emperor of the French from 1804 until 1814 and again in 1815. Napoleon's political and cultural legacy has endured, and he has been one of the most celebrated and controversial leaders in world history.",
9
+ "Kobe Bryant": "Kobe Bean Bryant (August 23, 1978 – January 26, 2020) was an American professional basketball player. A shooting guard, he spent his entire 20-year career with the Los Angeles Lakers in the National Basketball Association (NBA). Widely regarded as one of the greatest basketball players of all time, Bryant won five NBA championships, was an 18-time All-Star, a 15-time member of the All-NBA Team, a 12-time member of the All-Defensive Team, the 2008 NBA Most Valuable Player (MVP), and a two-time NBA Finals MVP. Bryant also led the NBA in scoring twice, and ranks fourth in league all-time regular season and postseason scoring. He was posthumously voted into the Naismith Memorial Basketball Hall of Fame in 2020 and named to the NBA 75th Anniversary Team in 2021.",
10
+ "Google": "Originally known as BackRub. Google is a search engine that started development in 1996 by Sergey Brin and Larry Page as a research project at Stanford University to find files on the Internet. Larry and Sergey later decided the name of their search engine needed to change and chose Google, which is inspired from the term googol. The company is headquartered in Mountain View, California."
11
+ }
12
+
13
+ urls = {
14
+ "Crypto": "https://www.investopedia.com/terms/c/cryptocurrency.asp",
15
+ "Jhonny Depp": "https://www.britannica.com/biography/Johnny-Depp",
16
+ "Rome": "https://www.timeout.com/rome/things-to-do/best-things-to-do-in-rome"
17
+ }
18
+
19
+ st.header("Extracting a Knowledge Base from text")
20
+ st_model_load = st.text('Loading NER model... It may take a while.')
21
+
22
+ @st.cache(allow_output_mutation=True)
23
+ def load_model():
24
+ print("Loading model...")
25
+ tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
26
+ model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
27
+ print("Model loaded!")
28
+ return tokenizer, model
29
+
30
+ tokenizer, model = load_model()
31
+ st.success('Model loaded!')
32
+ st_model_load.text("")
33
+
34
+ # sidebar
35
+ with st.sidebar:
36
+ st.header("What is a Knowledge Base")
37
+ st.markdown("A [**Knowledge Base (KB)**](https://en.wikipedia.org/wiki/Knowledge_base) is information stored in structured data, ready to be used for analysis or inference. Usually a KB is stored as a graph (i.e. a [**Knowledge Graph**](https://www.ibm.com/cloud/learn/knowledge-graph)), where nodes are **entities** and edges are **relations** between entities.")
38
+ st.markdown("_For example, the from the text \"Fabio lives in Italy\" we can extract the relation triplet <Fabio, lives in, Italy>, where \"Fabio\" and \"Italy\" are entities._")
39
+ st.header("How to build a Knowledge Graph")
40
+ st.markdown("To build a Knowledge Graph from text, we typically need to perform two steps:\n- Extract entities, a.k.a. **Named Entity Recognition (NER)**, i.e. the nodes.\n- Extract relations between the entities, a.k.a. **Relation Classification (RC)**, i.e. the edges.\nRecently, end-to-end approaches have been proposed to tackle both tasks simultaneously. This task is usually referred to as **Relation Extraction (RE)**. In this demo, an end-to-end model called [**REBEL**](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf) is used.")
41
+ st.header("How REBEL works")
42
+ st.markdown("REBEL is a **text2text** model obtained by fine-tuning [**BART**](https://huggingface.co/docs/transformers/model_doc/bart) for translating a raw input sentence containing entities and implicit relations into a set of triplets that explicitly refer to those relations. You can find [REBEL in the Hugging Face Hub](https://huggingface.co/Babelscape/rebel-large).")
43
+ st.header("Further steps")
44
+ st.markdown("Even though they are not visualized, the knowledge graph saves information about the provenience of each relation (e.g. from which articles it has been exrtacted and other metadata), along with Wikipedia data about each entity.")
45
+ st.markdown("Other libraries used:\n- [wikipedia](https://pypi.org/project/wikipedia/): For validating extracted entities checking if they have a corresponding Wikipedia page.\n- [newspaper](https://github.com/codelucas/newspaper): For parsing articles from URLs.\n- [pyvis](https://pyvis.readthedocs.io/en/latest/index.html): For graphs visualizations.\n- [GoogleNews](https://github.com/Iceloof/GoogleNews): For reading Google News latest articles about a topic.")
46
+ st.header("Considerations")
47
+ st.markdown("If you look closely at the extracted knowledge graphs, some extracted relations are false. Indeed, relation extraction models are still far from perfect and require further steps in the pipeline to build reliable knowledge graphs. Consider this demo as a starting step!")
48
+
49
+ # Choose from where to generate the KB
50
+ options = [
51
+ "Text",
52
+ "Article at URL",
53
+ "Multiple news articles"
54
+ ]
55
+ if 'option' not in st.session_state:
56
+ st.session_state.option = options[0]
57
+ option = st.selectbox('Build a Knowledge Base from:', options, index=options.index(st.session_state.option))
58
+
59
+ text_option, text = None, None
60
+ url_option, url = None, None
61
+ news_option = None
62
+
63
+ if option == "Text":
64
+ text_options = [
65
+ "Napoleon",
66
+ "Kobe Bryant",
67
+ "Google",
68
+ "Free text"
69
+ ]
70
+ if 'text_option' not in st.session_state or st.session_state.text_option is None:
71
+ st.session_state.text_option = text_options[0]
72
+ text_option = st.selectbox('Choose text option:', text_options, index=text_options.index(st.session_state.text_option))
73
+
74
+ disabled = False
75
+ if text_option != "Free text":
76
+ disabled = True
77
+ text = texts[text_option]
78
+ else:
79
+ if 'text' not in st.session_state:
80
+ st.session_state.text = ""
81
+ text = st.session_state.text
82
+ text = st.text_area('Text:', value=text, height=300, disabled=disabled)
83
+ elif option == "Article at URL":
84
+ url_options = [
85
+ "Crypto",
86
+ "Jhonny Depp",
87
+ "Rome",
88
+ "Free URL"
89
+ ]
90
+ if 'url_option' not in st.session_state or st.session_state.url_option is None:
91
+ st.session_state.url_option = url_options[0]
92
+ url_option = st.selectbox('Choose URL option:', url_options, index=url_options.index(st.session_state.url_option))
93
+
94
+ disabled = False
95
+ if url_option != "Free URL":
96
+ disabled = True
97
+ url = urls[url_option]
98
+ else:
99
+ if 'url' not in st.session_state:
100
+ st.session_state.url = ""
101
+ url = st.session_state.url
102
+ url = st.text_input('URL:', value=url, disabled=disabled)
103
+ else:
104
+ news_options = [
105
+ "Google",
106
+ "Apple",
107
+ "Elon Musk",
108
+ "Kobe Bryant"
109
+ ]
110
+ if 'news_option' not in st.session_state or st.session_state.news_option is None:
111
+ st.session_state.news_option = news_options[0]
112
+ news_option = st.selectbox('Use articles about:', news_options, index=news_options.index(st.session_state.news_option))
113
+
114
+ placeholder = st.empty()
115
+
116
+ def generate_kb():
117
+ st.session_state.option = option
118
+ st.session_state.text_option = text_option
119
+ st.session_state.text = text
120
+ st.session_state.url_option = url_option
121
+ st.session_state.url = url
122
+ st.session_state.news_option = news_option
123
+
124
+ # load correct kb
125
+ if option == "Text":
126
+ if text_option == "Napoleon":
127
+ kb = utils.load_kb("networks/network_1_napoleon.p")
128
+ elif text_option == "Kobe Bryant":
129
+ kb = utils.load_kb("networks/network_1_bryant.p")
130
+ elif text_option == "Google":
131
+ kb = utils.load_kb("networks/network_1_google.p")
132
+ else:
133
+ kb = utils.from_text_to_kb(text, model, tokenizer, "", verbose=True)
134
+ elif option == "Article at URL":
135
+ if url_option == "Crypto":
136
+ kb = utils.load_kb("networks/network_2_crypto.p")
137
+ elif url_option == "Jhonny Depp":
138
+ kb = utils.load_kb("networks/network_2_depp.p")
139
+ elif url_option == "Rome":
140
+ kb = utils.load_kb("networks/network_2_rome.p")
141
+ else:
142
+ kb = utils.from_url_to_kb(url, model, tokenizer)
143
+ else:
144
+ if news_option == "Google":
145
+ kb = utils.load_kb("networks/network_3_google.p")
146
+ elif news_option == "Apple":
147
+ kb = utils.load_kb("networks/network_3_apple.p")
148
+ elif news_option == "Elon Musk":
149
+ kb = utils.load_kb("networks/network_3_musk.p")
150
+ elif news_option == "Kobe Bryant":
151
+ kb = utils.load_kb("networks/network_3_bryant.p")
152
+
153
+ # save chart
154
+ utils.save_network_html(kb, filename="networks/network.html")
155
+ st.session_state.kb_chart = "networks/network.html"
156
+ st.session_state.kb_text = kb.get_textual_representation()
157
+
158
+
159
+ st.session_state.option = option
160
+ st.session_state.text_option = text_option
161
+ st.session_state.text = text
162
+ st.session_state.url_option = url_option
163
+ st.session_state.url = url
164
+ st.session_state.news_option = news_option
165
+
166
+ button_text = "Show KB"
167
+ if (option == "Text" and text_option == "Free text") or (option == "Article at URL" and url_option == "Free URL"):
168
+ button_text = "Generate KB"
169
+
170
+ # generate KB button
171
+ st.button(button_text, on_click=generate_kb)
172
+
173
+ # kb chart session state
174
+ if 'kb_chart' not in st.session_state:
175
+ st.session_state.kb_chart = None
176
+ if 'kb_text' not in st.session_state:
177
+ st.session_state.kb_text = None
178
+
179
+ # show graph
180
+ if st.session_state.kb_chart:
181
+ with st.container():
182
+ st.subheader("Generated KB")
183
+ st.markdown("*You can interact with the graph and zoom.*")
184
+ html_source_code = open(st.session_state.kb_chart, 'r', encoding='utf-8').read()
185
+ components.html(html_source_code, width=700, height=700)
186
+ st.markdown(st.session_state.kb_text)
kb.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wikipedia
2
+
3
+ class KB():
4
+ def __init__(self):
5
+ self.entities = {} # { entity_title: {...} }
6
+ self.relations = [] # [ head: entity_title, type: ..., tail: entity_title,
7
+ # meta: { article_url: { spans: [...] } } ]
8
+ self.sources = {} # { article_url: {...} }
9
+
10
+ def merge_with_kb(self, kb2):
11
+ for r in kb2.relations:
12
+ article_url = list(r["meta"].keys())[0]
13
+ source_data = kb2.sources[article_url]
14
+ self.add_relation(r, source_data["article_title"],
15
+ source_data["article_publish_date"])
16
+
17
+ def are_relations_equal(self, r1, r2):
18
+ return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
19
+
20
+ def exists_relation(self, r1):
21
+ return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
22
+
23
+ def merge_relations(self, r2):
24
+ r1 = [r for r in self.relations
25
+ if self.are_relations_equal(r2, r)][0]
26
+
27
+ # if different article
28
+ article_url = list(r2["meta"].keys())[0]
29
+ if article_url not in r1["meta"]:
30
+ r1["meta"][article_url] = r2["meta"][article_url]
31
+
32
+ # if existing article
33
+ else:
34
+ spans_to_add = [span for span in r2["meta"][article_url]["spans"]
35
+ if span not in r1["meta"][article_url]["spans"]]
36
+ r1["meta"][article_url]["spans"] += spans_to_add
37
+
38
+ def get_wikipedia_data(self, candidate_entity):
39
+ try:
40
+ page = wikipedia.page(candidate_entity, auto_suggest=False)
41
+ entity_data = {
42
+ "title": page.title,
43
+ "url": page.url,
44
+ "summary": page.summary
45
+ }
46
+ return entity_data
47
+ except:
48
+ return None
49
+
50
+ def add_entity(self, e):
51
+ self.entities[e["title"]] = {k:v for k,v in e.items() if k != "title"}
52
+
53
+ def add_relation(self, r, article_title, article_publish_date):
54
+ # check on wikipedia
55
+ candidate_entities = [r["head"], r["tail"]]
56
+ entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]
57
+
58
+ # if one entity does not exist, stop
59
+ if any(ent is None for ent in entities):
60
+ return
61
+
62
+ # manage new entities
63
+ for e in entities:
64
+ self.add_entity(e)
65
+
66
+ # rename relation entities with their wikipedia titles
67
+ r["head"] = entities[0]["title"]
68
+ r["tail"] = entities[1]["title"]
69
+
70
+ # add source if not in kb
71
+ article_url = list(r["meta"].keys())[0]
72
+ if article_url not in self.sources:
73
+ self.sources[article_url] = {
74
+ "article_title": article_title,
75
+ "article_publish_date": article_publish_date
76
+ }
77
+
78
+ # manage new relation
79
+ if not self.exists_relation(r):
80
+ self.relations.append(r)
81
+ else:
82
+ self.merge_relations(r)
83
+
84
+ def get_textual_representation(self):
85
+ res = ""
86
+ res += "### Entities\n"
87
+ for e in self.entities.items():
88
+ # shorten summary
89
+ e_temp = (e[0], {k:(v[:100] + "..." if k == "summary" else v) for k,v in e[1].items()})
90
+ res += f"- {e_temp}\n"
91
+ res += "\n"
92
+ res += "### Relations\n"
93
+ for r in self.relations:
94
+ res += f"- {r}\n"
95
+ res += "\n"
96
+ res += "### Sources\n"
97
+ for s in self.sources.items():
98
+ res += f"- {s}\n"
99
+ return res
networks/.DS_Store ADDED
Binary file (6.15 kB). View file
 
networks/network_1_bryant.p ADDED
Binary file (20.9 kB). View file
 
networks/network_1_google.p ADDED
Binary file (11.2 kB). View file
 
networks/network_1_napoleon.p ADDED
Binary file (11.9 kB). View file
 
networks/network_2_crypto.p ADDED
Binary file (37.7 kB). View file
 
networks/network_2_depp.p ADDED
Binary file (7.83 kB). View file
 
networks/network_2_rome.p ADDED
Binary file (4.92 kB). View file
 
networks/network_3_amazon.p ADDED
Binary file (153 kB). View file
 
networks/network_3_apple.p ADDED
Binary file (227 kB). View file
 
networks/network_3_bryant.p ADDED
Binary file (185 kB). View file
 
networks/network_3_google.p ADDED
Binary file (190 kB). View file
 
networks/network_3_musk.p ADDED
Binary file (113 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ pyvis
4
+ GoogleNews
5
+ newspaper3k
6
+ wikipedia
utils.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pyvis.network import Network
2
+ from GoogleNews import GoogleNews
3
+ from newspaper import Article, ArticleException
4
+ import math
5
+ import torch
6
+ from kb import KB
7
+ import pickle
8
+
9
+ def extract_relations_from_model_output(text):
10
+ relations = []
11
+ relation, subject, relation, object_ = '', '', '', ''
12
+ text = text.strip()
13
+ current = 'x'
14
+ text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
15
+ for token in text_replaced.split():
16
+ if token == "<triplet>":
17
+ current = 't'
18
+ if relation != '':
19
+ relations.append({
20
+ 'head': subject.strip(),
21
+ 'type': relation.strip(),
22
+ 'tail': object_.strip()
23
+ })
24
+ relation = ''
25
+ subject = ''
26
+ elif token == "<subj>":
27
+ current = 's'
28
+ if relation != '':
29
+ relations.append({
30
+ 'head': subject.strip(),
31
+ 'type': relation.strip(),
32
+ 'tail': object_.strip()
33
+ })
34
+ object_ = ''
35
+ elif token == "<obj>":
36
+ current = 'o'
37
+ relation = ''
38
+ else:
39
+ if current == 't':
40
+ subject += ' ' + token
41
+ elif current == 's':
42
+ object_ += ' ' + token
43
+ elif current == 'o':
44
+ relation += ' ' + token
45
+ if subject != '' and relation != '' and object_ != '':
46
+ relations.append({
47
+ 'head': subject.strip(),
48
+ 'type': relation.strip(),
49
+ 'tail': object_.strip()
50
+ })
51
+ return relations
52
+
53
+ def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
54
+ article_publish_date=None, verbose=False):
55
+ # tokenize whole text
56
+ inputs = tokenizer([text], return_tensors="pt")
57
+
58
+ # compute span boundaries
59
+ num_tokens = len(inputs["input_ids"][0])
60
+ if verbose:
61
+ print(f"Input has {num_tokens} tokens")
62
+ num_spans = math.ceil(num_tokens / span_length)
63
+ if verbose:
64
+ print(f"Input has {num_spans} spans")
65
+ overlap = math.ceil((num_spans * span_length - num_tokens) /
66
+ max(num_spans - 1, 1))
67
+ spans_boundaries = []
68
+ start = 0
69
+ for i in range(num_spans):
70
+ spans_boundaries.append([start + span_length * i,
71
+ start + span_length * (i + 1)])
72
+ start -= overlap
73
+ if verbose:
74
+ print(f"Span boundaries are {spans_boundaries}")
75
+
76
+ # transform input with spans
77
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
78
+ for boundary in spans_boundaries]
79
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
80
+ for boundary in spans_boundaries]
81
+ inputs = {
82
+ "input_ids": torch.stack(tensor_ids),
83
+ "attention_mask": torch.stack(tensor_masks)
84
+ }
85
+
86
+ # generate relations
87
+ num_return_sequences = 3
88
+ gen_kwargs = {
89
+ "max_length": 256,
90
+ "length_penalty": 0,
91
+ "num_beams": 3,
92
+ "num_return_sequences": num_return_sequences
93
+ }
94
+ generated_tokens = model.generate(
95
+ **inputs,
96
+ **gen_kwargs,
97
+ )
98
+
99
+ # decode relations
100
+ decoded_preds = tokenizer.batch_decode(generated_tokens,
101
+ skip_special_tokens=False)
102
+
103
+ # create kb
104
+ kb = KB()
105
+ i = 0
106
+ for sentence_pred in decoded_preds:
107
+ current_span_index = i // num_return_sequences
108
+ relations = extract_relations_from_model_output(sentence_pred)
109
+ for relation in relations:
110
+ relation["meta"] = {
111
+ article_url: {
112
+ "spans": [spans_boundaries[current_span_index]]
113
+ }
114
+ }
115
+ kb.add_relation(relation, article_title, article_publish_date)
116
+ i += 1
117
+
118
+ return kb
119
+
120
+ def get_article(url):
121
+ article = Article(url)
122
+ article.download()
123
+ article.parse()
124
+ return article
125
+
126
+ def from_url_to_kb(url, model, tokenizer):
127
+ article = get_article(url)
128
+ config = {
129
+ "article_title": article.title,
130
+ "article_publish_date": article.publish_date
131
+ }
132
+ kb = from_text_to_kb(article.text, model, tokenizer, article.url, **config)
133
+ return kb
134
+
135
+ def get_news_links(query, lang="en", region="US", pages=1):
136
+ googlenews = GoogleNews(lang=lang, region=region)
137
+ googlenews.search(query)
138
+ all_urls = []
139
+ for page in range(pages):
140
+ googlenews.get_page(page)
141
+ all_urls += googlenews.get_links()
142
+ return list(set(all_urls))
143
+
144
+ def from_urls_to_kb(urls, model, tokenizer, verbose=False):
145
+ kb = KB()
146
+ if verbose:
147
+ print(f"{len(urls)} links to visit")
148
+ for url in urls:
149
+ if verbose:
150
+ print(f"Visiting {url}...")
151
+ try:
152
+ kb_url = from_url_to_kb(url, model, tokenizer)
153
+ kb.merge_with_kb(kb_url)
154
+ except ArticleException:
155
+ if verbose:
156
+ print(f" Couldn't download article at url {url}")
157
+ return kb
158
+
159
+ def save_network_html(kb, filename="network.html"):
160
+ # create network
161
+ net = Network(directed=True, width="700px", height="700px")
162
+
163
+ # nodes
164
+ color_entity = "#00FF00"
165
+ for e in kb.entities:
166
+ net.add_node(e, shape="circle", color=color_entity)
167
+
168
+ # edges
169
+ for r in kb.relations:
170
+ net.add_edge(r["head"], r["tail"],
171
+ title=r["type"], label=r["type"])
172
+
173
+ # save network
174
+ net.repulsion(
175
+ node_distance=200,
176
+ central_gravity=0.2,
177
+ spring_length=200,
178
+ spring_strength=0.05,
179
+ damping=0.09
180
+ )
181
+ net.set_edge_smooth('dynamic')
182
+ net.show(filename)
183
+
184
+ def save_kb(kb, filename):
185
+ with open(filename, "wb") as f:
186
+ pickle.dump(kb, f)
187
+
188
+ class CustomUnpickler(pickle.Unpickler):
189
+ def find_class(self, module, name):
190
+ if name == 'KB':
191
+ return KB
192
+ return super().find_class(module, name)
193
+
194
+ def load_kb(filename):
195
+ res = None
196
+ with open(filename, "rb") as f:
197
+ res = CustomUnpickler(f).load()
198
+ return res