Fix bambu client connection.

This commit is contained in:
Anton Skrypnyk 2024-07-24 17:15:47 +03:00
parent ed33fd8fb1
commit 38a6f58306
4 changed files with 65 additions and 61 deletions

View File

@ -5,12 +5,10 @@ __license__ = "GNU Affero General Public License http://www.gnu.org/licenses/agp
import collections import collections
from dataclasses import dataclass, field from dataclasses import dataclass, field
import math import math
import os
import queue import queue
import re import re
import threading import threading
import time import time
import asyncio
from octoprint_bambu_printer.printer.print_job import PrintJob from octoprint_bambu_printer.printer.print_job import PrintJob
from pybambu import BambuClient, commands from pybambu import BambuClient, commands
import logging import logging
@ -99,13 +97,10 @@ class BambuVirtualPrinter:
self._serial_io.start() self._serial_io.start()
self._printer_thread.start() self._printer_thread.start()
self._bambu_client: BambuClient = None self._bambu_client: BambuClient = self._create_client_connection_async()
asyncio.get_event_loop().run_until_complete(self._create_connection_async())
@property @property
def bambu_client(self): def bambu_client(self):
if self._bambu_client is None:
raise ValueError("No connection to Bambulab was established")
return self._bambu_client return self._bambu_client
@property @property
@ -133,7 +128,7 @@ class BambuVirtualPrinter:
def update_print_job_info(self): def update_print_job_info(self):
print_job_info = self.bambu_client.get_device().print_job print_job_info = self.bambu_client.get_device().print_job
filename: str = print_job_info.get("subtask_name") filename: str = print_job_info.subtask_name
project_file_info = self.file_system.get_data_by_suffix( project_file_info = self.file_system.get_data_by_suffix(
filename, [".3mf", ".gcode.3mf"] filename, [".3mf", ".gcode.3mf"]
) )
@ -146,7 +141,7 @@ class BambuVirtualPrinter:
self.sendOk() self.sendOk()
# fuzzy math here to get print percentage to match BambuStudio # fuzzy math here to get print percentage to match BambuStudio
progress = print_job_info.get("print_percentage") progress = print_job_info.print_percentage
self._current_print_job = PrintJob(project_file_info, 0) self._current_print_job = PrintJob(project_file_info, 0)
self._current_print_job.progress = progress self._current_print_job.progress = progress
@ -195,20 +190,27 @@ class BambuVirtualPrinter:
self._log.debug(f"on connect called") self._log.debug(f"on connect called")
return on_connect return on_connect
async def _create_connection_async(self): def _create_client_connection_async(self):
self._create_client_connection()
if self._bambu_client is None:
raise RuntimeError("Connection with Bambu Client not established")
return self._bambu_client
def _create_client_connection(self):
if ( if (
self._settings.get(["device_type"]) == "" self._settings.get(["device_type"]) == ""
or self._settings.get(["serial"]) == "" or self._settings.get(["serial"]) == ""
or self._settings.get(["username"]) == "" or self._settings.get(["username"]) == ""
or self._settings.get(["access_code"]) == "" or self._settings.get(["access_code"]) == ""
): ):
self._log.debug("invalid settings to start connection with Bambu Printer") msg = "invalid settings to start connection with Bambu Printer"
return self._log.debug(msg)
raise ValueError(msg)
self._log.debug( self._log.debug(
f"connecting via local mqtt: {self._settings.get_boolean(['local_mqtt'])}" f"connecting via local mqtt: {self._settings.get_boolean(['local_mqtt'])}"
) )
self._bambu_client = BambuClient( bambu_client = BambuClient(
device_type=self._settings.get(["device_type"]), device_type=self._settings.get(["device_type"]),
serial=self._settings.get(["serial"]), serial=self._settings.get(["serial"]),
host=self._settings.get(["host"]), host=self._settings.get(["host"]),
@ -223,18 +225,17 @@ class BambuVirtualPrinter:
email=self._settings.get(["email"]), email=self._settings.get(["email"]),
auth_token=self._settings.get(["auth_token"]), auth_token=self._settings.get(["auth_token"]),
) )
self._bambu_client.on_disconnect = self.on_disconnect( bambu_client.on_disconnect = self.on_disconnect(bambu_client.on_disconnect)
self._bambu_client.on_disconnect bambu_client.on_connect = self.on_connect(bambu_client.on_connect)
) bambu_client.connect(callback=self.new_update)
self._bambu_client.on_connect = self.on_connect(self._bambu_client.on_connect) self._log.info(f"bambu connection status: {bambu_client.connected}")
self._bambu_client.connect(callback=self.new_update)
self._log.info(f"bambu connection status: {self._bambu_client.connected}")
self._serial_io.sendOk() self._serial_io.sendOk()
self._bambu_client = bambu_client
def __str__(self): def __str__(self):
return "BAMBU(read_timeout={read_timeout},write_timeout={write_timeout},options={options})".format( return "BAMBU(read_timeout={read_timeout},write_timeout={write_timeout},options={options})".format(
read_timeout=self._read_timeout, read_timeout=self.timeout,
write_timeout=self._write_timeout, write_timeout=self.write_timeout,
options={ options={
"device_type": self._settings.get(["device_type"]), "device_type": self._settings.get(["device_type"]),
"host": self._settings.get(["host"]), "host": self._settings.get(["host"]),
@ -258,21 +259,21 @@ class BambuVirtualPrinter:
@property @property
def timeout(self): def timeout(self):
return self._read_timeout return self._serial_io._read_timeout
@timeout.setter @timeout.setter
def timeout(self, value): def timeout(self, value):
self._log.debug(f"Setting read timeout to {value}s") self._log.debug(f"Setting read timeout to {value}s")
self._read_timeout = value self._serial_io._read_timeout = value
@property @property
def write_timeout(self): def write_timeout(self):
return self._write_timeout return self._serial_io._write_timeout
@write_timeout.setter @write_timeout.setter
def write_timeout(self, value): def write_timeout(self, value):
self._log.debug(f"Setting write timeout to {value}s") self._log.debug(f"Setting write timeout to {value}s")
self._write_timeout = value self._serial_io._write_timeout = value
@property @property
def port(self): def port(self):
@ -505,6 +506,8 @@ class BambuVirtualPrinter:
self._state_change_queue.join() self._state_change_queue.join()
def _printer_worker(self): def _printer_worker(self):
self._create_client_connection_async()
self.sendIO("Printer connection complete")
while self._running: while self._running:
try: try:
next_state = self._state_change_queue.get(timeout=0.01) next_state = self._state_change_queue.get(timeout=0.01)

View File

@ -6,6 +6,7 @@ import queue
import re import re
import threading import threading
import time import time
import traceback
from typing import Callable from typing import Callable
from octoprint.util import to_bytes, to_unicode from octoprint.util import to_bytes, to_unicode
@ -35,6 +36,7 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
self._read_timeout = read_timeout self._read_timeout = read_timeout
self._write_timeout = write_timeout self._write_timeout = write_timeout
self.current_line = 0
self._received_lines = 0 self._received_lines = 0
self._wait_interval = 5.0 self._wait_interval = 5.0
self._running = True self._running = True
@ -42,7 +44,7 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
self._rx_buffer_size = 64 self._rx_buffer_size = 64
self._incoming_lock = threading.RLock() self._incoming_lock = threading.RLock()
self.input_bytes = CharCountingQueue(self._rx_buffer_size, name="RxBuffer") self.input_bytes = queue.Queue(self._rx_buffer_size)
self.output_bytes = queue.Queue() self.output_bytes = queue.Queue()
self._error_detected: Exception | None = None self._error_detected: Exception | None = None
@ -77,8 +79,8 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
except Exception as e: except Exception as e:
self._error_detected = e self._error_detected = e
self.input_bytes.task_done() self.input_bytes.task_done()
self.input_bytes.clear() self._clearQueue(self.input_bytes)
break self._log.info("\n".join(traceback.format_exception(e)[-50:]))
self._log.debug("Closing IO read loop") self._log.debug("Closing IO read loop")
@ -98,6 +100,9 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
def flush(self): def flush(self):
self.input_bytes.join() self.input_bytes.join()
self.raise_if_error()
def raise_if_error(self):
if self._error_detected is not None: if self._error_detected is not None:
raise self._error_detected raise self._error_detected
@ -110,11 +115,9 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
return 0 return 0
try: try:
written = self.input_bytes.put( self.input_bytes.put(data, timeout=self._write_timeout)
data, timeout=self._write_timeout, partial=True
)
self._log.debug(f"<<< {u_data}") self._log.debug(f"<<< {u_data}")
return written return len(data)
except queue.Full: except queue.Full:
self._log.error( self._log.error(
"Incoming queue is full, raising SerialTimeoutException" "Incoming queue is full, raising SerialTimeoutException"
@ -169,29 +172,11 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
self.send(self._format_error("checksum_missing")) self.send(self._format_error("checksum_missing"))
return return
# track N = N + 1 line = self._process_linenumber_marker(data)
linenumber = 0 if line is None:
if data.startswith(b"N") and b"M110" in data:
linenumber = int(re.search(b"N([0-9]+)", data).group(1))
self.lastN = linenumber
self.current_line = linenumber
self.sendOk()
return return
elif data.startswith(b"N"):
linenumber = int(re.search(b"N([0-9]+)", data).group(1))
expected = self.lastN + 1
if linenumber != expected:
self._triggerResend(actual=linenumber)
return
else:
self.lastN = linenumber
data = data.split(None, 1)[1].strip()
data += b"\n"
command = to_unicode(data, encoding="ascii", errors="replace").strip()
command = to_unicode(line, encoding="ascii", errors="replace").strip()
command_match = self.command_regex.match(command) command_match = self.command_regex.match(command)
if command_match is not None: if command_match is not None:
gcode = command_match.group(0) gcode = command_match.group(0)
@ -199,6 +184,25 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase):
else: else:
self._log.warn(f'Not a valid gcode command "{command}"') self._log.warn(f'Not a valid gcode command "{command}"')
def _process_linenumber_marker(self, data: bytes):
linenumber = 0
if data.startswith(b"N") and b"M110" in data:
linenumber = int(re.search(b"N([0-9]+)", data).group(1))
self.lastN = linenumber
self.current_line = linenumber
self.sendOk()
return None
elif data.startswith(b"N"):
linenumber = int(re.search(b"N([0-9]+)", data).group(1))
expected = self.lastN + 1
if linenumber != expected:
self._triggerResend(actual=linenumber)
return None
else:
self.lastN = linenumber
data = data.split(None, 1)[1].strip()
return data
def _triggerResend( def _triggerResend(
self, self,
expected: int | None = None, expected: int | None = None,

View File

@ -90,7 +90,7 @@ class RemoteSDCardFileList:
def _connect_ftps_server(self): def _connect_ftps_server(self):
host = self._settings.get(["host"]) host = self._settings.get(["host"])
access_code = self._settings.get(["access_code"]) access_code = self._settings.get(["access_code"])
ftp = IoTFTPSClient(str(host), 990, "bblp", str(access_code), ssl_implicit=True) ftp = IoTFTPSClient(f"{host}", 990, "bblp", f"{access_code}", ssl_implicit=True)
return ftp return ftp
def _get_file_data(self, file_path: str) -> FileInfo | None: def _get_file_data(self, file_path: str) -> FileInfo | None:

View File

@ -124,7 +124,8 @@ def ftps_session_mock(files_info_ftp):
@fixture(scope="function") @fixture(scope="function")
def print_job_mock(): def print_job_mock():
print_job = MagicMock() print_job = MagicMock()
print_job.get.side_effect = DictGetter({"subtask_name": "", "print_percentage": 0}) print_job.subtask_name = ""
print_job.print_percentage = 0
return print_job return print_job
@ -162,7 +163,7 @@ def printer(
async def _mock_connection(self): async def _mock_connection(self):
pass pass
BambuVirtualPrinter._create_connection_async = _mock_connection BambuVirtualPrinter._create_client_connection_async = _mock_connection
serial_obj = BambuVirtualPrinter( serial_obj = BambuVirtualPrinter(
settings, settings,
profile_manager, profile_manager,
@ -224,9 +225,7 @@ def test_print_started_with_selected_file(printer: BambuVirtualPrinter, print_jo
assert printer.file_system.selected_file is not None assert printer.file_system.selected_file is not None
assert printer.file_system.selected_file.file_name == "print.3mf" assert printer.file_system.selected_file.file_name == "print.3mf"
print_job_mock.get.side_effect = DictGetter( print_job_mock.subtask_name = "print.3mf"
{"subtask_name": "print.3mf", "print_percentage": 0}
)
printer.write(b"M24\n") printer.write(b"M24\n")
printer.flush() printer.flush()
@ -237,9 +236,7 @@ def test_print_started_with_selected_file(printer: BambuVirtualPrinter, print_jo
def test_pause_print(printer: BambuVirtualPrinter, bambu_client_mock, print_job_mock): def test_pause_print(printer: BambuVirtualPrinter, bambu_client_mock, print_job_mock):
print_job_mock.get.side_effect = DictGetter( print_job_mock.subtask_name = "print.3mf"
{"subtask_name": "print.3mf", "print_percentage": 0}
)
printer.write(b"M20\n") printer.write(b"M20\n")
printer.write(b"M23 print.3mf\n") printer.write(b"M23 print.3mf\n")