From 40364735e0e2f1cbecdb491e7fe4b6c237cc69e1 Mon Sep 17 00:00:00 2001 From: AB Date: Fri, 28 Feb 2020 09:53:00 +0000 Subject: [PATCH] Improve markov chain text generator. --- database.py | 10 ++++++++-- worker.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/database.py b/database.py index fce0805..069c596 100755 --- a/database.py +++ b/database.py @@ -165,8 +165,14 @@ class DataBase: result = self.execute(sql) return(result) - def get_random_message(self): - sql = "SELECT text FROM xxx_message ORDER BY RANDOM() LIMIT 1" + def get_random_message(self, conf_id=None, count=1): + if not conf_id: + print('get random message from all DB') + sql = "SELECT text FROM xxx_message ORDER BY RANDOM() LIMIT %s" % count + else: + print('get random message from %s ' % conf_id) + sql = """SELECT x.text FROM xxx_message x LEFT JOIN relations r ON r.msg_id == x.id + WHERE r.conf_id = '%s' ORDER BY RANDOM() DESC LIMIT 1""" % conf_id result = self.execute(sql) return(result[0][0]) diff --git a/worker.py b/worker.py index aed9a78..402ed1d 100755 --- a/worker.py +++ b/worker.py @@ -196,8 +196,14 @@ class MessageWorker: count = max_sen except: count = 5 - for i in range(0, count): - rand_messages.append(self.db.get_random_message()) + try: + use_all = bool(msg['message']['text'][8:]) + except: + use_all = False + if use_all: + rand_messages.append(self.db.get_random_message(count=count)) + else: + rand_messages.append(self.db.get_random_message(conf_id, count=count)) rand_text = " ".join(rand_messages) gen_text = get(rand_text) try: