-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathworker.py
More file actions
107 lines (80 loc) · 2.92 KB
/
Copy pathworker.py
File metadata and controls
107 lines (80 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Worker that listens to events from the Flask webserver and runs the keras model
"""
import redis
import json, random
import logging, sys
from server.HiveModel import HiveModel
from server.Config import Config
# setup logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO)

# Connect to redis. ``encoding`` is the supported keyword for the codec
# (``charset`` is a deprecated alias for the same setting); with
# decode_responses=True replies come back as ``str`` instead of ``bytes``.
r = redis.StrictRedis(host=Config.REDIS_HOST, port=6379, encoding="utf-8", decode_responses=True)

# setup model
model = HiveModel(path=Config.IMAGE_PATH, img_size=Config.IMAGE_SIZE)

# Subscribe to the command channel. The subscription is made here, before
# the model/metrics reset and the message loop run, so messages published in
# the meantime are queued on this connection rather than dropped (Redis does
# not deliver pub/sub messages to channels nobody is subscribed to).
logging.info("Listening for new labels")
p = r.pubsub()
p.subscribe('hive_messages')
def run_test(model, r, test_id, *, history_size=65):
    """Evaluate a single test image and record the result in redis.

    Args:
        model: object exposing ``evaluate(test_id)`` which returns an
            ``(accuracy, label)`` pair.
        r: redis client.
        test_id: index of the test image; also used as the position in the
            ``test_labels`` / ``test_scores`` redis lists.
        history_size: how many of the most recent accuracies to keep in the
            ``accuracies`` list. The default of 65 matches the previous
            hard-coded ``ltrim(..., 0, 64)`` (LTRIM's stop index is inclusive).

    Raises:
        ValueError: if ``history_size`` is smaller than 1.
    """
    if history_size < 1:
        raise ValueError("history_size must be at least 1")
    acc, label = model.evaluate(test_id)
    # LPUSH puts the newest value at the head, so keeping indices
    # [0, history_size - 1] retains the history_size most recent accuracies.
    r.lpush('accuracies', acc)
    r.ltrim('accuracies', 0, history_size - 1)
    # LSET only succeeds if the list already has an element at test_id
    # (reset() pre-fills both lists with one -1 per test image).
    r.lset('test_labels', test_id, label)
    r.lset('test_scores', test_id, acc)
def reset(r, model, test_ids):
    """Reset the keras model and all redis metrics, then re-baseline.

    Re-initialises the model, clears the accuracy history, zeroes the
    annotation counter, and recreates ``test_labels`` / ``test_scores`` as
    lists holding one ``-1`` placeholder per test image. It then evaluates
    every id in ``test_ids`` so the untrained model has a baseline score.

    Args:
        r: redis client.
        model: model exposing ``init_model()`` and a ``_test_x`` array whose
            first dimension is the number of test images.
        test_ids: test image indices to evaluate after the reset.
    """
    # init (or reset) model
    model.init_model()
    n_test = model._test_x.shape[0]

    # init redis metrics
    r.delete('accuracies', 'test_labels', 'test_scores')
    r.set('annotation_count', 0)
    # RPUSH requires at least one value, and Redis rejects the command
    # otherwise. With an empty test set there is nothing to pre-size, so
    # leave the lists absent.
    if n_test:
        # -1 is a placeholder until run_test() overwrites the slot via LSET
        placeholders = [-1] * n_test
        r.rpush('test_labels', *placeholders)
        r.rpush('test_scores', *placeholders)

    # first evaluate all test images to provide baseline for untrained model
    for test_id in test_ids:
        run_test(model, r, test_id)
# Shuffled visiting order for test images: after every annotation the worker
# re-evaluates the next test image in this order, wrapping around at the end.
test_ids = list(range(0, model._test_x.shape[0]))
random.shuffle(test_ids)
test_index = 0

# start with clean model and metrics
reset(r, model, test_ids)

# p.listen() blocks until the next pub/sub message arrives. Polling
# get_message() in a tight loop would return immediately when nothing is
# pending and keep an idle worker spinning at full CPU.
for message in p.listen():
    # only pick up real messages (ignore subscribe confirmations, etc.)
    if message['type'] != 'message':
        continue

    # A malformed payload (bad JSON, non-object JSON, no 'action') is logged
    # and skipped instead of crashing the whole worker.
    try:
        data = json.loads(message['data'])
        action = data['action']
    except (ValueError, KeyError, TypeError):
        logging.warning("Ignoring malformed message: %r", message['data'])
        continue

    if action == 'label':
        # annotation message: perform one training cycle
        model.label(data['image_id'], data['class_id'])
        # increase counter for total number of annotations
        r.incr('annotation_count')
        # evaluate the next test image in the shuffled cycle
        if test_ids:
            test_index = (test_index + 1) % len(test_ids)
            run_test(model, r, test_ids[test_index])

    elif action == 'reset':
        logging.info("Resetting model and metrics")
        reset(r, model, test_ids)

    elif action == 'simulate':
        # Simulate perfect annotations: publish one 'label' message per
        # training image with its ground-truth class. Because the worker is
        # subscribed to 'hive_messages', this same loop processes them later.
        for image_id, _ in enumerate(model._train_x):
            class_id = model._train_labels[image_id]
            label_msg = {
                'action': 'label',
                'image_id': int(image_id),
                'class_id': int(class_id)
            }
            r.publish('hive_messages', json.dumps(label_msg))