darknet v3
proverbot.py
Go to the documentation of this file.
1 from darknet import *
2 
def _feed_char(net, d, c):
    """Run one ``predict`` step on the one-hot encoding of character ``c``.

    ``d`` is the shared 256-slot input buffer. Slot ``ord(c)`` is set to 1 only
    for the duration of the ``predict`` call and cleared afterwards, so the
    buffer is all zeros between calls. Returns whatever ``predict`` returns.
    """
    # NOTE(review): characters with ord(c) >= 256 index past the buffer and fail.
    d[ord(c)] = 1
    pred = predict(net, d)
    d[ord(c)] = 0
    return pred


def predict_tactic(net, s, max_len=None):
    """Sample one tactic from ``net``, priming it with the text ``s``.

    Each character of ``s`` is fed to the network in turn. Characters are then
    sampled from the network's output, each sampled character being fed back
    as the next input, until a '.' is produced.

    Args:
        net: network handle passed straight through to ``predict``.
        s: priming text. An empty string is treated as ``'\\n'``.
        max_len: optional cap on the number of sampled characters. ``None``
            (the default) keeps sampling until a '.' appears.

    Returns:
        ``(tac, prob)``: the sampled text, ending in '.' unless ``max_len``
        cut it short, and the sum of the natural-log probabilities of exactly
        the characters in ``tac``.
    """
    import math  # explicit; don't rely on `math` leaking via `from darknet import *`

    prob = 0
    d = c_array(c_float, [0.0] * 256)
    chars = []
    if not s:
        s = '\n'
    # Advance the network over all but the last priming character; outputs discarded.
    for c in s[:-1]:
        _feed_char(net, d, c)
    c = s[-1]
    while max_len is None or len(chars) < max_len:
        raw = _feed_char(net, d, c)
        pred = [raw[i] for i in range(256)]
        ind = sample(pred)
        c = chr(ind)
        chars.append(c)
        # Score only characters that are kept. Previously one extra char was
        # sampled after the '.' and its log-prob was counted but not returned.
        prob += math.log(pred[ind])
        if c == '.':
            break
    return (''.join(chars), prob)
26 
def predict_tactics(net, s, n):
    """Draw ``n`` tactic samples for the prompt ``s`` and rank them.

    ``reset_rnn(net)`` is called before every draw (presumably clearing the
    recurrent state so draws are independent). Returns a list of the
    ``(tactic, score)`` pairs produced by ``predict_tactic``, highest score
    first; pairs with equal scores keep their draw order (the sort is stable).
    """
    samples = []
    for _ in range(n):
        reset_rnn(net)
        samples.append(predict_tactic(net, s))
    samples.sort(key=lambda pair: -pair[1])
    return samples
34 
# Demo run: load the Coq model and print 10 sampled tactics for a fixed prompt.
# The paths are the original author's; edit these constants to point at your files.
CFG_FILE = "cfg/coq.test.cfg"
WEIGHTS_FILE = "/home/pjreddie/backup/coq.backup"

net = load_net(CFG_FILE, WEIGHTS_FILE, 0)
t = predict_tactics(net, "+++++\n", 10)
print(t)  # call form: valid in Python 3 and prints the same list in Python 2
def sample(probs)
Definition: darknet.py:5
void reset_rnn(network *net)
Definition: network.c:85
def predict_tactic(net, s)
Definition: proverbot.py:3
def c_array(ctype, values)
Definition: darknet.py:15
predict
Definition: darknet.py:54
def predict_tactics(net, s, n)
Definition: proverbot.py:27
load_net
Definition: darknet.py:85