Fix bambu client connection.
This commit is contained in:
		| @@ -5,12 +5,10 @@ __license__ = "GNU Affero General Public License http://www.gnu.org/licenses/agp | ||||
| import collections | ||||
| from dataclasses import dataclass, field | ||||
| import math | ||||
| import os | ||||
| import queue | ||||
| import re | ||||
| import threading | ||||
| import time | ||||
| import asyncio | ||||
| from octoprint_bambu_printer.printer.print_job import PrintJob | ||||
| from pybambu import BambuClient, commands | ||||
| import logging | ||||
| @@ -99,13 +97,10 @@ class BambuVirtualPrinter: | ||||
|         self._serial_io.start() | ||||
|         self._printer_thread.start() | ||||
|  | ||||
|         self._bambu_client: BambuClient = None | ||||
|         asyncio.get_event_loop().run_until_complete(self._create_connection_async()) | ||||
|         self._bambu_client: BambuClient = self._create_client_connection_async() | ||||
|  | ||||
|     @property | ||||
|     def bambu_client(self): | ||||
|         if self._bambu_client is None: | ||||
|             raise ValueError("No connection to Bambulab was established") | ||||
|         return self._bambu_client | ||||
|  | ||||
|     @property | ||||
| @@ -133,7 +128,7 @@ class BambuVirtualPrinter: | ||||
|  | ||||
|     def update_print_job_info(self): | ||||
|         print_job_info = self.bambu_client.get_device().print_job | ||||
|         filename: str = print_job_info.get("subtask_name") | ||||
|         filename: str = print_job_info.subtask_name | ||||
|         project_file_info = self.file_system.get_data_by_suffix( | ||||
|             filename, [".3mf", ".gcode.3mf"] | ||||
|         ) | ||||
| @@ -146,7 +141,7 @@ class BambuVirtualPrinter: | ||||
|             self.sendOk() | ||||
|  | ||||
|         # fuzzy math here to get print percentage to match BambuStudio | ||||
|         progress = print_job_info.get("print_percentage") | ||||
|         progress = print_job_info.print_percentage | ||||
|         self._current_print_job = PrintJob(project_file_info, 0) | ||||
|         self._current_print_job.progress = progress | ||||
|  | ||||
| @@ -195,20 +190,27 @@ class BambuVirtualPrinter: | ||||
|         self._log.debug(f"on connect called") | ||||
|         return on_connect | ||||
|  | ||||
|     async def _create_connection_async(self): | ||||
|     def _create_client_connection_async(self): | ||||
|         self._create_client_connection() | ||||
|         if self._bambu_client is None: | ||||
|             raise RuntimeError("Connection with Bambu Client not established") | ||||
|         return self._bambu_client | ||||
|  | ||||
|     def _create_client_connection(self): | ||||
|         if ( | ||||
|             self._settings.get(["device_type"]) == "" | ||||
|             or self._settings.get(["serial"]) == "" | ||||
|             or self._settings.get(["username"]) == "" | ||||
|             or self._settings.get(["access_code"]) == "" | ||||
|         ): | ||||
|             self._log.debug("invalid settings to start connection with Bambu Printer") | ||||
|             return | ||||
|             msg = "invalid settings to start connection with Bambu Printer" | ||||
|             self._log.debug(msg) | ||||
|             raise ValueError(msg) | ||||
|  | ||||
|         self._log.debug( | ||||
|             f"connecting via local mqtt: {self._settings.get_boolean(['local_mqtt'])}" | ||||
|         ) | ||||
|         self._bambu_client = BambuClient( | ||||
|         bambu_client = BambuClient( | ||||
|             device_type=self._settings.get(["device_type"]), | ||||
|             serial=self._settings.get(["serial"]), | ||||
|             host=self._settings.get(["host"]), | ||||
| @@ -223,18 +225,17 @@ class BambuVirtualPrinter: | ||||
|             email=self._settings.get(["email"]), | ||||
|             auth_token=self._settings.get(["auth_token"]), | ||||
|         ) | ||||
|         self._bambu_client.on_disconnect = self.on_disconnect( | ||||
|             self._bambu_client.on_disconnect | ||||
|         ) | ||||
|         self._bambu_client.on_connect = self.on_connect(self._bambu_client.on_connect) | ||||
|         self._bambu_client.connect(callback=self.new_update) | ||||
|         self._log.info(f"bambu connection status: {self._bambu_client.connected}") | ||||
|         bambu_client.on_disconnect = self.on_disconnect(bambu_client.on_disconnect) | ||||
|         bambu_client.on_connect = self.on_connect(bambu_client.on_connect) | ||||
|         bambu_client.connect(callback=self.new_update) | ||||
|         self._log.info(f"bambu connection status: {bambu_client.connected}") | ||||
|         self._serial_io.sendOk() | ||||
|         self._bambu_client = bambu_client | ||||
|  | ||||
|     def __str__(self): | ||||
|         return "BAMBU(read_timeout={read_timeout},write_timeout={write_timeout},options={options})".format( | ||||
|             read_timeout=self._read_timeout, | ||||
|             write_timeout=self._write_timeout, | ||||
|             read_timeout=self.timeout, | ||||
|             write_timeout=self.write_timeout, | ||||
|             options={ | ||||
|                 "device_type": self._settings.get(["device_type"]), | ||||
|                 "host": self._settings.get(["host"]), | ||||
| @@ -258,21 +259,21 @@ class BambuVirtualPrinter: | ||||
|  | ||||
|     @property | ||||
|     def timeout(self): | ||||
|         return self._read_timeout | ||||
|         return self._serial_io._read_timeout | ||||
|  | ||||
|     @timeout.setter | ||||
|     def timeout(self, value): | ||||
|         self._log.debug(f"Setting read timeout to {value}s") | ||||
|         self._read_timeout = value | ||||
|         self._serial_io._read_timeout = value | ||||
|  | ||||
|     @property | ||||
|     def write_timeout(self): | ||||
|         return self._write_timeout | ||||
|         return self._serial_io._write_timeout | ||||
|  | ||||
|     @write_timeout.setter | ||||
|     def write_timeout(self, value): | ||||
|         self._log.debug(f"Setting write timeout to {value}s") | ||||
|         self._write_timeout = value | ||||
|         self._serial_io._write_timeout = value | ||||
|  | ||||
|     @property | ||||
|     def port(self): | ||||
| @@ -505,6 +506,8 @@ class BambuVirtualPrinter: | ||||
|         self._state_change_queue.join() | ||||
|  | ||||
|     def _printer_worker(self): | ||||
|         self._create_client_connection_async() | ||||
|         self.sendIO("Printer connection complete") | ||||
|         while self._running: | ||||
|             try: | ||||
|                 next_state = self._state_change_queue.get(timeout=0.01) | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import queue | ||||
| import re | ||||
| import threading | ||||
| import time | ||||
| import traceback | ||||
| from typing import Callable | ||||
|  | ||||
| from octoprint.util import to_bytes, to_unicode | ||||
| @@ -35,6 +36,7 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|         self._read_timeout = read_timeout | ||||
|         self._write_timeout = write_timeout | ||||
|  | ||||
|         self.current_line = 0 | ||||
|         self._received_lines = 0 | ||||
|         self._wait_interval = 5.0 | ||||
|         self._running = True | ||||
| @@ -42,7 +44,7 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|         self._rx_buffer_size = 64 | ||||
|         self._incoming_lock = threading.RLock() | ||||
|  | ||||
|         self.input_bytes = CharCountingQueue(self._rx_buffer_size, name="RxBuffer") | ||||
|         self.input_bytes = queue.Queue(self._rx_buffer_size) | ||||
|         self.output_bytes = queue.Queue() | ||||
|         self._error_detected: Exception | None = None | ||||
|  | ||||
| @@ -77,8 +79,8 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|             except Exception as e: | ||||
|                 self._error_detected = e | ||||
|                 self.input_bytes.task_done() | ||||
|                 self.input_bytes.clear() | ||||
|                 break | ||||
|                 self._clearQueue(self.input_bytes) | ||||
|                 self._log.info("\n".join(traceback.format_exception(e)[-50:])) | ||||
|  | ||||
|         self._log.debug("Closing IO read loop") | ||||
|  | ||||
| @@ -98,6 +100,9 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|  | ||||
|     def flush(self): | ||||
|         self.input_bytes.join() | ||||
|         self.raise_if_error() | ||||
|  | ||||
|     def raise_if_error(self): | ||||
|         if self._error_detected is not None: | ||||
|             raise self._error_detected | ||||
|  | ||||
| @@ -110,11 +115,9 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|                 return 0 | ||||
|  | ||||
|             try: | ||||
|                 written = self.input_bytes.put( | ||||
|                     data, timeout=self._write_timeout, partial=True | ||||
|                 ) | ||||
|                 self.input_bytes.put(data, timeout=self._write_timeout) | ||||
|                 self._log.debug(f"<<< {u_data}") | ||||
|                 return written | ||||
|                 return len(data) | ||||
|             except queue.Full: | ||||
|                 self._log.error( | ||||
|                     "Incoming queue is full, raising SerialTimeoutException" | ||||
| @@ -169,29 +172,11 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|             self.send(self._format_error("checksum_missing")) | ||||
|             return | ||||
|  | ||||
|         # track N = N + 1 | ||||
|         linenumber = 0 | ||||
|         if data.startswith(b"N") and b"M110" in data: | ||||
|             linenumber = int(re.search(b"N([0-9]+)", data).group(1)) | ||||
|             self.lastN = linenumber | ||||
|             self.current_line = linenumber | ||||
|             self.sendOk() | ||||
|         line = self._process_linenumber_marker(data) | ||||
|         if line is None: | ||||
|             return | ||||
|         elif data.startswith(b"N"): | ||||
|             linenumber = int(re.search(b"N([0-9]+)", data).group(1)) | ||||
|             expected = self.lastN + 1 | ||||
|             if linenumber != expected: | ||||
|                 self._triggerResend(actual=linenumber) | ||||
|                 return | ||||
|             else: | ||||
|                 self.lastN = linenumber | ||||
|  | ||||
|             data = data.split(None, 1)[1].strip() | ||||
|  | ||||
|         data += b"\n" | ||||
|  | ||||
|         command = to_unicode(data, encoding="ascii", errors="replace").strip() | ||||
|  | ||||
|         command = to_unicode(line, encoding="ascii", errors="replace").strip() | ||||
|         command_match = self.command_regex.match(command) | ||||
|         if command_match is not None: | ||||
|             gcode = command_match.group(0) | ||||
| @@ -199,6 +184,25 @@ class PrinterSerialIO(threading.Thread, BufferedIOBase): | ||||
|         else: | ||||
|             self._log.warn(f'Not a valid gcode command "{command}"') | ||||
|  | ||||
|     def _process_linenumber_marker(self, data: bytes): | ||||
|         linenumber = 0 | ||||
|         if data.startswith(b"N") and b"M110" in data: | ||||
|             linenumber = int(re.search(b"N([0-9]+)", data).group(1)) | ||||
|             self.lastN = linenumber | ||||
|             self.current_line = linenumber | ||||
|             self.sendOk() | ||||
|             return None | ||||
|         elif data.startswith(b"N"): | ||||
|             linenumber = int(re.search(b"N([0-9]+)", data).group(1)) | ||||
|             expected = self.lastN + 1 | ||||
|             if linenumber != expected: | ||||
|                 self._triggerResend(actual=linenumber) | ||||
|                 return None | ||||
|             else: | ||||
|                 self.lastN = linenumber | ||||
|             data = data.split(None, 1)[1].strip() | ||||
|         return data | ||||
|  | ||||
|     def _triggerResend( | ||||
|         self, | ||||
|         expected: int | None = None, | ||||
|   | ||||
| @@ -90,7 +90,7 @@ class RemoteSDCardFileList: | ||||
|     def _connect_ftps_server(self): | ||||
|         host = self._settings.get(["host"]) | ||||
|         access_code = self._settings.get(["access_code"]) | ||||
|         ftp = IoTFTPSClient(str(host), 990, "bblp", str(access_code), ssl_implicit=True) | ||||
|         ftp = IoTFTPSClient(f"{host}", 990, "bblp", f"{access_code}", ssl_implicit=True) | ||||
|         return ftp | ||||
|  | ||||
|     def _get_file_data(self, file_path: str) -> FileInfo | None: | ||||
|   | ||||
| @@ -124,7 +124,8 @@ def ftps_session_mock(files_info_ftp): | ||||
| @fixture(scope="function") | ||||
| def print_job_mock(): | ||||
|     print_job = MagicMock() | ||||
|     print_job.get.side_effect = DictGetter({"subtask_name": "", "print_percentage": 0}) | ||||
|     print_job.subtask_name = "" | ||||
|     print_job.print_percentage = 0 | ||||
|     return print_job | ||||
|  | ||||
|  | ||||
| @@ -162,7 +163,7 @@ def printer( | ||||
|     async def _mock_connection(self): | ||||
|         pass | ||||
|  | ||||
|     BambuVirtualPrinter._create_connection_async = _mock_connection | ||||
|     BambuVirtualPrinter._create_client_connection_async = _mock_connection | ||||
|     serial_obj = BambuVirtualPrinter( | ||||
|         settings, | ||||
|         profile_manager, | ||||
| @@ -224,9 +225,7 @@ def test_print_started_with_selected_file(printer: BambuVirtualPrinter, print_jo | ||||
|     assert printer.file_system.selected_file is not None | ||||
|     assert printer.file_system.selected_file.file_name == "print.3mf" | ||||
|  | ||||
|     print_job_mock.get.side_effect = DictGetter( | ||||
|         {"subtask_name": "print.3mf", "print_percentage": 0} | ||||
|     ) | ||||
|     print_job_mock.subtask_name = "print.3mf" | ||||
|  | ||||
|     printer.write(b"M24\n") | ||||
|     printer.flush() | ||||
| @@ -237,9 +236,7 @@ def test_print_started_with_selected_file(printer: BambuVirtualPrinter, print_jo | ||||
|  | ||||
|  | ||||
| def test_pause_print(printer: BambuVirtualPrinter, bambu_client_mock, print_job_mock): | ||||
|     print_job_mock.get.side_effect = DictGetter( | ||||
|         {"subtask_name": "print.3mf", "print_percentage": 0} | ||||
|     ) | ||||
|     print_job_mock.subtask_name = "print.3mf" | ||||
|  | ||||
|     printer.write(b"M20\n") | ||||
|     printer.write(b"M23 print.3mf\n") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user