Skip to content

Commit

Permalink
reduce model size
Browse files Browse the repository at this point in the history
  • Loading branch information
saengowp committed Aug 27, 2023
1 parent de1e77d commit feeeaf2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cgrcompute/components/courserecommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def train(observations: list[set[Hashable]]) -> 'CosineSimRecommendationModel':
neigh = []
for c in sim.getrow(i).nonzero()[1]:
neigh.append((items[c], sim[i, c]))
neigh = sorted(neigh, key=lambda x: x[1])[-300:]
neigh = sorted(neigh, key=lambda x: x[1])[-100:]
ccmtx[cid] = dict(neigh)
return CosineSimRecommendationModel(ccmtx)

Expand All @@ -42,7 +42,7 @@ def infer(self, selected_item: list[Hashable]) -> dict[Hashable, float]:
d[pcid] += scr
except KeyError:
pass
return dict(sorted(d.items(), key=lambda x: x[1])[-300:])
return dict(sorted(d.items(), key=lambda x: x[1])[-100:])

class CourseRecommendationModel:

Expand All @@ -60,7 +60,7 @@ def populate(self):
def infer(self, selected_courses):
res = self.model.infer(selected_courses)
res = sorted(res.items(), key=lambda x:-x[1])
return [course for course, score in res][:300]
return [course for course, score in res][:100]

def downloadobsvdata(self, es: ElasticService):
self.logger.info('Download observation')
Expand All @@ -77,7 +77,7 @@ def downloadobsvdata(self, es: ElasticService):
cnt += 1
if cnt % 10000 == 0:
self.logger.info("Downloaded {} observations".format(cnt))
if cnt >= 900000:
if cnt >= 100000:
break
self.logger.info('Received {} observations'.format(cnt))
obsv = [l for _, l in obsv.items() if len(l) > 4]
Expand Down

0 comments on commit feeeaf2

Please sign in to comment.