From 9dd690c5f404e467ce646109b74a384f3c8a8365 Mon Sep 17 00:00:00 2001 From: Marvin Martinson Date: Sun, 2 Dec 2018 21:56:27 +0200 Subject: [PATCH] Push --- .gitignore | 6 ++ LICENSE | 21 ++++ Makefile | 35 +++++++ README.md | 27 ++++++ boot.py | 61 ++++++++++++ config.json | 25 +++++ debounce.py | 57 +++++++++++ hal.py | 257 +++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 179 ++++++++++++++++++++++++++++++++++ uwebsockets.py | 241 ++++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 909 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 boot.py create mode 100644 config.json create mode 100644 debounce.py create mode 100644 hal.py create mode 100644 main.py create mode 100644 uwebsockets.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..61efa50 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# ignore backup files +._* +# ignore the ESP32 MicroPython binary +esp32*.bin +# ignore the development config file +config-dev.json diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..66cefce --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 RoboKoding LTD + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5af2a5f --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +#!/bin/bash + +SERIAL_PORT=/dev/ttyUSB0 +#SERIAL_PORT=/dev/tty.SLAB_USBtoUART +#SERIAL_PORT=/dev/tty.wchusbserial1410 + +all: flash delay libs config update reset +upload: config update reset console + +delay: + sleep 3 + +reset: + esptool.py -p $(SERIAL_PORT) --after hard_reset read_mac + picocom -b 115200 $(SERIAL_PORT) + +libs: + ampy -p $(SERIAL_PORT) put uwebsockets.py + ampy -p $(SERIAL_PORT) put debounce.py + +update: + ampy -p $(SERIAL_PORT) put hal.py + ampy -p $(SERIAL_PORT) put main.py + ampy -p $(SERIAL_PORT) put boot.py + +config: + ampy -p $(SERIAL_PORT) put config.json + +flash: + esptool.py -p $(SERIAL_PORT) --chip esp32 -b 115200 erase_flash + esptool.py -p $(SERIAL_PORT) --chip esp32 -b 115200 write_flash --flash_mode dio 0x1000 esp32-*.bin + +console: + echo "Ctrl-A + Ctrl-Q to close Picocom" + picocom -b 115200 $(SERIAL_PORT) diff --git a/README.md b/README.md new file mode 100644 index 0000000..ea12f11 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# sumorobot-firmware + +The software that is running on the SumoRobots + +Code + +# Instructions +* Change the SERIAL_PORT in the Makefile +* Add your WiFi networks to the config.json file +* Install [Python](https://www.python.org/downloads/) +* Install [esptool](https://github.com/espressif/esptool) (to flash MicroPython to the ESP32) +* Install [ampy](https://github.com/adafruit/ampy) (for uploading files) +* Download [the MicroPython binary](http://micropython.org/download#esp32) to this directory +* Upload the MicroPython binary and the SumoRobot firmware to your ESP32 (open a terminal and type: make all) + +# Support +If you find our work useful, please consider donating : ) +[![Donate using Liberapay](https://liberapay.com/assets/widgets/donate.svg)](https://liberapay.com/robokoding/donate) + + +# TODOS +* variable motor speed control, then more interesting for kids + + + +# Credits +* [K-SPACE MTÜ](https://k-space.ee/) diff --git a/boot.py b/boot.py new file mode 100644 index 0000000..c8f417b --- /dev/null +++ b/boot.py @@ -0,0 +1,61 @@ +import ujson +import network +from hal import * +from utime import sleep_ms +from machine import Timer, Pin +import ubinascii + +print("Press Ctrl-C to stop boot script...") +sleep_ms(200) + +#Pin(25, Pin.OUT).value(0) + + +# Open and parse the config file +with open("config.json", "r") as config_file: + config = ujson.load(config_file) + +# if not config["sumo_id"]: +# config["sumo_id"] = ubinascii.hexlify(network.WLAN().config('mac'),':').decode().replace(":","")[6:] +# with open("config.part", "w") as config_file: +# config_file.write(ujson.dumps(config)) +# os.rename("config.part", "config.json") + +config["sumo_id"] = ubinascii.hexlify(network.WLAN().config('mac'),':').decode().replace(":","")[6:] + +sleep_ms(500) + +robotName = "Sumo-"+config["sumo_id"] +# Initialize the SumoRobot object +sumorobot = Sumorobot(config) + +# Indiacte booting with blinking status LED +#timer = Timer(0) +#sumorobot.toggle_led() +#timer.init(period=2000, mode=Timer.PERIODIC, callback=sumorobot.toggle_led) + +# Connect to WiFi +wlan = network.WLAN(network.STA_IF) + +# Activate the WiFi interface +wlan.active(True) + +wlan.config(dhcp_hostname=robotName) + +# If not already connected +if not wlan.isconnected(): + # Scan for WiFi networks + networks = wlan.scan() + # Go trough all scanned WiFi networks + for network in networks: + # Extract the networks SSID + ssid = network[0].decode("utf-8") + # Check if the SSID is in the config file + if ssid in config["wifis"].keys(): + # Start to connect to the pre-configured network + wlan.connect(ssid, config["wifis"][ssid]) + break + +# Clean up +import gc +gc.collect() diff --git a/config.json b/config.json new file mode 100644 index 0000000..fc19260 --- /dev/null +++ b/config.json @@ -0,0 +1,25 @@ +{ + "paly/stop_button_pin": 22, + "charging_pin":4, + "left_motor_pin":13, + "right_motor_pin":15, + "led_power_pin": 25, + "ultrasonic_echo_pin":14, + "ultrasonic_trigger_pin":12, + "line_left_pin":32, + "line_middle_pin":35, + "line_right_pin":34, + "sumo_id": "", + "firmware_version": 0.3, + "left_servo_tuning": 33, + "right_servo_tuning": 33, + "ultrasonic_distance": 40, + "left_line_threshold": 4500, + "middle_line_threshold": 4500, + "right_line_threshold": 4500, + "motors_reverse": 1, + "sumo_server": "sumo.koodur.com:80", + "wifis": { + "SSID": "PAROOL" + } +} diff --git a/debounce.py b/debounce.py new file mode 100644 index 0000000..60d65c1 --- /dev/null +++ b/debounce.py @@ -0,0 +1,57 @@ +# +# inspired by: https://forum.micropython.org/viewtopic.php?t=1938#p10931 +# +import micropython + +try: + from machine import Timer + timer_init = lambda t, p, cb: t.init(period=p, callback=cb) +except ImportError: + from pyb import Timer + timer_init = lambda t, p, cb: t.init(freq=1000 // p, callback=cb) + +# uncomment when debugging callback problems +#micropython.alloc_emergency_exception_buf(100) + + +class DebouncedSwitch: + def __init__(self, sw, cb, arg=None, delay=50, tid=4): + self.sw = sw + # Create references to bound methods beforehand + # http://docs.micropython.org/en/latest/pyboard/library/micropython.html#micropython.schedule + self._sw_cb = self.sw_cb + self._tim_cb = self.tim_cb + self._set_cb = getattr(self.sw, 'callback', None) or self.sw.irq + self.delay = delay + self.tim = Timer(tid) + self.callback(cb, arg) + + def sw_cb(self, pin=None): + self._set_cb(None) + timer_init(self.tim, self.delay, self._tim_cb) + + def tim_cb(self, tim): + tim.deinit() + if self.sw(): + micropython.schedule(self.cb, self.arg) + self._set_cb(self._sw_cb if self.cb else None) + + def callback(self, cb, arg=None): + self.tim.deinit() + self.cb = cb + self.arg = arg + self._set_cb(self._sw_cb if cb else None) + + +def test_pyb(ledno=1): + import pyb + sw = pyb.Switch() + led = pyb.LED(ledno) + return DebouncedSwitch(sw, lambda l: l.toggle(), led) + + +def test_machine(swpin=2, ledpin=16): + from machine import Pin + sw = Pin(swpin, Pin.IN) + led = Pin(ledpin, Pin.OUT) + return DebouncedSwitch(sw, lambda l: l.value(not l.value()), led) diff --git a/hal.py b/hal.py new file mode 100644 index 0000000..7e4f36b --- /dev/null +++ b/hal.py @@ -0,0 +1,257 @@ +import os +import ujson +from utime import sleep_us, sleep_ms +from machine import Pin, PWM, ADC, time_pulse_us, deepsleep +import random + +# LEDs +STATUS = 0 +#OPPONENT = 1 +#LEFT_LINE = 2 +#RIGHT_LINE = 3 + +# Directions +STOP = 0 +LEFT = 1 +MIDDLE = 6 +RIGHT = 2 +SEARCH = 3 +FORWARD = 4 +BACKWARD = 5 + + +#states +MOVING = 0 +STANDBY = 1 + +class Sumorobot(object): + # Constructor + def __init__(self, config = None): + # Config file + self.config = config + + self.state = STANDBY + + self.name = "Sumo-"+self.config["sumo_id"] + + # Ultrasonic distance sensor + self.echo = Pin(self.config["ultrasonic_echo_pin"], Pin.IN) + self.trigger = Pin(self.config["ultrasonic_trigger_pin"], Pin.OUT) + + # Servo PWM-s + self.pwm_left = PWM(Pin(self.config["left_motor_pin"]), freq=50, duty=0) + self.pwm_right = PWM(Pin(self.config["right_motor_pin"]), freq=50, duty=0) + + # Bottom status LED + self.led_power = Pin(self.config["led_power_pin"], Pin.OUT) + self.charging = Pin(self.config["charging_pin"]); + + self.playStop = Pin(self.config["paly/stop_button_pin"], Pin.IN, Pin.PULL_UP) + + + self.led_power.value(0) + + self.adc_line_left = ADC(Pin(32)) + self.adc_line_middle = ADC(Pin(35)) + self.adc_line_right = ADC(Pin(34)) + + # Set reference voltage to 3.3V + self.adc_line_left.atten(ADC.ATTN_11DB) + self.adc_line_right.atten(ADC.ATTN_11DB) + self.adc_line_middle.atten(ADC.ATTN_11DB) + + # To smooth out ultrasonic sensor value + self.opponent_score = 0 + + # For terminating sleep + self.terminate = False + + # For search mode + self.search = 0 + self.search_counter = 0 + + # Memorise previous servo speeds + self.prev_speed = {LEFT: 0, RIGHT: 0} + + #saving line sensor valus, to read once in 50ms loop + self.line_left = 0 + self.line_right = 0; + self.line_middle = 0; + + self.speedForward = 100 if self.config["motors_reverse"] == 0 else -100 + self.speedReverse = -100 if self.config["motors_reverse"] == 0 else 100 + + + # Function to get distance (cm) from the object in front of the SumoRobot + def get_opponent_distance(self): + # Send a pulse + self.trigger.value(0) + sleep_us(5) + self.trigger.value(1) + sleep_us(10) + self.trigger.value(0) + # Wait for the pulse and calculate the distance + return (time_pulse_us(self.echo, 1, 30000) / 2) / 29.1 + + # Function to get boolean if there is something in front of the SumoRobot + def is_opponent(self): + # Get the opponent distance + self.opponent_distance = self.get_opponent_distance() + # When the opponent is close and the ping actually returned + if self.opponent_distance < self.config["ultrasonic_distance"] and self.opponent_distance > 0: + # When not maximum score + if self.opponent_score < 5: + # Increase the opponent score + self.opponent_score += 1 + # When no opponent was detected + else: + # When not lowest score + if self.opponent_score > 0: + # Decrease the opponent score + self.opponent_score -= 1 + + # When the sensor saw something more than 2 times + opponent = True if self.opponent_score > 2 else False + + # Trigger opponent LED + #self.set_led(OPPONENT, opponent) + + return opponent + + # Function to update line calibration and write it to the config file + def calibrate_line(self): + # Read the line sensor values + self.config["left_line_threshold"] = self.adc_line_left.read() + self.config["right_line_threshold"] = self.adc_line_right.read() + self.config["middle_line_threshold"] = self.adc_line_middle.read() + # Update the config file + with open("config.part", "w") as config_file: + config_file.write(ujson.dumps(config)) + os.rename("config.part", "config.json") + + # Function to get light inensity from the phototransistors + def get_line(self, dir): + # Check if the direction is valid + assert dir in (LEFT, RIGHT, MIDDLE) + + # Return the given line sensor value + if dir == LEFT: + return self.adc_line_left.read() + elif dir == RIGHT: + return self.adc_line_right.read() + elif dir == MIDDLE: + return self.adc_line_middle.read() + + def is_line(self, dir): + # Check if the direction is valid + assert dir in (LEFT, RIGHT, MIDDLE) + + # Return the given line sensor value, storing it to variable, not to ask double in 50ms time loop + if dir == LEFT: + self.line_left = self.get_line(LEFT) + line = abs(self.line_left - self.config["left_line_threshold"]) > 1000 + #self.set_led(LEFT_LINE, line) + return line + elif dir == RIGHT: + self.line_right = self.get_line(RIGHT) + line = abs(self.line_right - self.config["right_line_threshold"]) > 1000 + #self.set_led(RIGHT_LINE, line) + return line + elif dir == MIDDLE: + self.line_middle = self.get_line(MIDDLE) + line = abs(self.line_middle - self.config["middle_line_threshold"]) > 1000 + return line + + def set_servo(self, dir, speed): + # Check if the direction is valid + assert dir in (LEFT, RIGHT) + # Check if the speed is valid + assert speed <= 100 and speed >= -100 + + # When the speed didn't change + if speed == self.prev_speed[dir]: + return + + # Record the new speed + self.prev_speed[dir] = speed + + # Set the given servo speed + if dir == LEFT: + if speed == 0: + self.pwm_left.duty(0) + else: + # -100 ... 100 to 33 .. 102 + self.pwm_left.duty(int(33 + self.config["left_servo_tuning"] + speed * 33 / 100)) + elif dir == RIGHT: + if speed == 0: + self.pwm_right.duty(0) + else: + # -100 ... 100 to 33 .. 102 + self.pwm_right.duty(int(33 + self.config["right_servo_tuning"] + speed * 33 / 100)) + + def move(self, dir): + # Check if the direction is valid + assert dir in (SEARCH, STOP, RIGHT, LEFT, BACKWARD, FORWARD) + # Go to the given direction + + if dir == STOP: + self.set_state(STANDBY) + else: + self.set_state(MOVING) + + if dir == STOP: + self.set_servo(LEFT, 0) + self.set_servo(RIGHT, 0) + elif dir == LEFT: + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedReverse) + elif dir == RIGHT: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedForward) + elif dir == SEARCH: + # Change search mode after X seconds + if self.search_counter == 50: + self.search = random.randrange(0,3) + self.search_counter = 0 + #self.search = 0 if self.search > 2 else self.search + # When in search mode + if self.search == 0: + # Go forward + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedReverse) + elif self.search == 1: + # Go left + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedReverse) + elif self.search == 2: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedForward) + # Increase search counter + self.search_counter += 1 + elif dir == FORWARD: + self.set_servo(LEFT, self.speedForward) + self.set_servo(RIGHT, self.speedReverse) + elif dir == BACKWARD: + self.set_servo(LEFT, self.speedReverse) + self.set_servo(RIGHT, self.speedForward) + + + def sleep(self, delay): + # Check for valid delay + assert delay > 0 + + # Split the delay into 50ms chunks + for j in range(0, delay, 50): + # Check for forceful termination + if self.terminate: + # Terminate the delay + return + else: + sleep_ms(50) + + def set_state(self,value): + assert value in (MOVING,STANDBY) + self.state = value + + def get_state(self): + return self.state diff --git a/main.py b/main.py new file mode 100644 index 0000000..a82e7e7 --- /dev/null +++ b/main.py @@ -0,0 +1,179 @@ +import _thread +import ubinascii +import ujson +import uwebsockets +import os +from debounce import DebouncedSwitch + +# Code to execute +ast = "" +executeCode = False + +# Scope, info to be sent to the client +scope = dict() + +def buttoncallback(p=True): + global executeCode + + sumorobot.terminate = executeCode + sumorobot.led_power.value(not executeCode) + sleep_ms(50) + executeCode = not executeCode + print(executeCode) + +def writeCodeTofile(data): + with open("code.part", "w") as code_file: + code_file.write(ujson.dumps(data)) + sleep_ms(50) + os.rename("code.part", "code") + sleep_ms(50) + + +def step(): + global scope + + while True: + + # Update scope + scope = dict( + line_left = sumorobot.line_left, + line_right = sumorobot.line_right, + line_middle = sumorobot.line_middle, + opponent = sumorobot.get_opponent_distance(), + battery_voltage = 0, + state = executeCode + ) + + # Execute code + if(executeCode): + exec(ast) + + #sumorobot.playStop.irq(trigger=Pin.IRQ_RISING, handler=buttoncallback) + sw = DebouncedSwitch(sumorobot.playStop, buttoncallback, "dummy") + + # When robot was stopped + if sumorobot.terminate: + # Disable forceful termination of delays in code + sumorobot.terminate = False + # Stop the robot + sumorobot.move(STOP) + # Leave time to process WebSocket commands + sleep_ms(50) + +def ws_handler(): + global executeCode + global ast + global has_wifi_connection + + while True: + # When WiFi has just been reconnected + if wlan.isconnected() and not has_wifi_connection: + #conn = uwebsockets.connect(url) + #sumorobot.set_led(STATUS, True) + has_wifi_connection = True + # When WiFi has just been disconnected + elif not wlan.isconnected() and has_wifi_connection: + #sumorobot.set_led(STATUS, False) + has_wifi_connection = False + elif not wlan.isconnected(): + # Continue to wait for a WiFi connection + continue + + try: # Try to read from the WebSocket + data = conn.recv() + except Exception as e: # Socket timeout, no data received + # Continue to try to read data + #print(e) + + #if not conn.open: + conn = uwebsockets.connect(url) + continue + + # When an empty frame was received + if not data: + # Continue to receive data + continue + elif b'forward' in data: + ast = "" + sumorobot.move(FORWARD) + elif b'backward' in data: + ast = "" + sumorobot.move(BACKWARD) + elif b'right' in data: + ast = "" + sumorobot.move(RIGHT) + elif b'left' in data: + ast = "" + sumorobot.move(LEFT) + elif b'ping' in data: + #conn.send(ujson.dumps({"cmd":"sensord", "data":ujson.dumps(scope)})) + conn.send(ujson.dumps({"cmd":"sensors", "data":ujson.loads(repr(scope).replace("'", '"').replace("False","false").replace("True","true"))})) + elif b'code' in data: + executeCode = False + try: + data = ujson.loads(data) + writeCodeTofile(data['val']) + conn.send(ujson.dumps({"cmd":"code_upload", "status":True})) + except Exception as e: + conn.send(ujson.dumps({"cmd":"code_upload", "status":False})) + continue + data['val'] = data['val'].replace(";;", "\n") + print(data['val']) + ast = compile(data['val'], "snippet", "exec") + elif b'start' in data: + #buttoncallback() + executeCode = True + sumorobot.led_power.value(1) + sleep_ms(50) + data = ujson.loads(data) + writeCodeTofile(data['val']) + data['val'] = data['val'].replace(";;", "\n") + ast = compile(data['val'], "snippet", "exec") + elif b'stop' in data: + #ast = "" + sumorobot.led_power.value(0) + executeCode = False + sumorobot.move(STOP) + # for terminating delays in code + sumorobot.terminate = True + elif b'calibrate_line' in data: + sumorobot.led_power.value(1) + sleep_ms(50) + sumorobot.calibrate_line() + sumorobot.led_power.value(0) + elif b'Gone' in data: + print("server said 410 Gone, attempting to reconnect...") + #conn = uwebsockets.connect(url) + else: + print("unknown cmd:", data) + +# Wait for WiFi to get connected +while not wlan.isconnected(): + sleep_ms(100) + +# Connect to the websocket +url = "ws://%s/p2p/sumo-%s/browser/" % (config['sumo_server'], config['sumo_id']) +conn = uwebsockets.connect(url) + +# Set X seconds timeout for socket reads +conn.settimeout(3) + +# Stop bootup blinking +#timer.deinit() + +# WiFi is connected +has_wifi_connection = True +# Indicate that the WebSocket is connected +#sumorobot.set_led(STATUS, True) + +if('code' in os.listdir()): + with open("code", "r") as code_file: + data = ujson.load(code_file).replace(";;", "\n") + ast = compile(data, "snippet", "exec") + #print(ast) + + +# Start the code processing thread +_thread.start_new_thread(step, ()) +# Start the Websocket processing thread +_thread.start_new_thread(ws_handler, ()) diff --git a/uwebsockets.py b/uwebsockets.py new file mode 100644 index 0000000..5ce1b3c --- /dev/null +++ b/uwebsockets.py @@ -0,0 +1,241 @@ +""" +Websockets client for micropython + +Based very heavily on +https://github.com/aaugustin/websockets/blob/master/websockets/client.py +""" + +#import usocket as socket +import os +import ure as re +import urandom as random +import ustruct as struct +import usocket as socket +import ubinascii as binascii +from ucollections import namedtuple + +# Opcodes +OP_CONT = const(0x0) +OP_TEXT = const(0x1) +OP_BYTES = const(0x2) +OP_CLOSE = const(0x8) +OP_PING = const(0x9) +OP_PONG = const(0xa) + +# Close codes +CLOSE_OK = const(1000) +CLOSE_GOING_AWAY = const(1001) +CLOSE_PROTOCOL_ERROR = const(1002) +CLOSE_DATA_NOT_SUPPORTED = const(1003) +CLOSE_BAD_DATA = const(1007) +CLOSE_POLICY_VIOLATION = const(1008) +CLOSE_TOO_BIG = const(1009) +CLOSE_MISSING_EXTN = const(1010) +CLOSE_BAD_CONDITION = const(1011) + +URL_RE = re.compile(r'ws://([A-Za-z0-9\-\.]+)(?:\:([0-9]+))?(/.+)?') +URI = namedtuple('URI', ('hostname', 'port', 'path')) + +def urlparse(uri): + match = URL_RE.match(uri) + if match: + return URI(match.group(1), int(match.group(2)), match.group(3)) + else: + raise ValueError("Invalid URL: %s" % uri) + +class Websocket: + is_client = False + + def __init__(self, sock): + self._sock = sock + self.open = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + def settimeout(self, timeout): + self._sock.settimeout(timeout) + + def read_frame(self, max_size=None): + # Frame header + byte1, byte2 = struct.unpack('!BB', self._sock.read(2)) + + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + fin = bool(byte1 & 0x80) + opcode = byte1 & 0x0f + + # Byte 2: MASK(1) LENGTH(7) + mask = bool(byte2 & (1 << 7)) + length = byte2 & 0x7f + + if length == 126: # Magic number, length header is 2 bytes + length, = struct.unpack('!H', self._sock.read(2)) + elif length == 127: # Magic number, length header is 8 bytes + length, = struct.unpack('!Q', self._sock.read(8)) + + if mask: # Mask is 4 bytes + mask_bits = self._sock.read(4) + + try: + data = self._sock.read(length) + except MemoryError: + # We can't receive this many bytes, close the socket + self.close(code=CLOSE_TOO_BIG) + return True, OP_CLOSE, None + + if mask: + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + return fin, opcode, data + + def write_frame(self, opcode, data=b''): + fin = True + mask = self.is_client # messages sent by client are masked + + length = len(data) + + # Frame header + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + byte1 = 0x80 if fin else 0 + byte1 |= opcode + + # Byte 2: MASK(1) LENGTH(7) + byte2 = 0x80 if mask else 0 + + if length < 126: # 126 is magic value to use 2-byte length header + byte2 |= length + self._sock.write(struct.pack('!BB', byte1, byte2)) + + elif length < (1 << 16): # Length fits in 2-bytes + byte2 |= 126 # Magic code + self._sock.write(struct.pack('!BBH', byte1, byte2, length)) + + elif length < (1 << 64): + byte2 |= 127 # Magic code + self._sock.write(struct.pack('!BBQ', byte1, byte2, length)) + + else: + raise ValueError() + + if mask: # Mask is 4 bytes + mask_bits = struct.pack('!I', random.getrandbits(32)) + self._sock.write(mask_bits) + + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + self._sock.write(data) + + def recv(self): + assert self.open + + while self.open: + try: + fin, opcode, data = self.read_frame() + except ValueError: + self._close() + return + + if not fin: + raise NotImplementedError() + + if opcode == OP_TEXT: + return data + elif opcode == OP_BYTES: + return data + elif opcode == OP_CLOSE: + self._close() + return + elif opcode == OP_PONG: + # Ignore this frame, keep waiting for a data frame + continue + elif opcode == OP_PING: + # We need to send a pong frame + self.write_frame(OP_PONG, data) + # And then wait to receive + continue + elif opcode == OP_CONT: + # This is a continuation of a previous frame + raise NotImplementedError(opcode) + else: + raise ValueError(opcode) + + def send(self, buf): + assert self.open + + if isinstance(buf, str): + opcode = OP_TEXT + buf = buf.encode('utf-8') + elif isinstance(buf, bytes): + opcode = OP_BYTES + else: + raise TypeError() + + self.write_frame(opcode, buf) + + def close(self, code=CLOSE_OK, reason=''): + if not self.open: + return + + buf = struct.pack('!H', code) + reason.encode('utf-8') + + self.write_frame(OP_CLOSE, buf) + self._close() + + def _close(self): + self.open = False + self._sock.close() + +class WebsocketClient(Websocket): + is_client = True + +def connect(uri): + """ + Connect a websocket. + """ + + # Parse the given WebSocket URI + uri = urlparse(uri) + assert uri + + # Connect the socket + sock = socket.socket() + addr = socket.getaddrinfo(uri.hostname, uri.port) + sock.connect(addr[0][4]) + + # Sec-WebSocket-Key is 16 bytes of random base64 encoded + key = binascii.b2a_base64(os.urandom(16))[:-1] + + # WebSocket initiation headers + headers = [ + b'GET %s HTTP/1.1' % uri.path or '/', + b'Upgrade: websocket', + b'Connection: Upgrade', + b'Host: %s:%s' % (uri.hostname, uri.port), + b'Origin: http://%s:%s' % (uri.hostname, uri.port), + b'Sec-WebSocket-Key: ' + key, + b'Sec-WebSocket-Version: 13', + b'', + b'' + ] + + # Concatenate the headers and add new lines + data = b'\r\n'.join(headers) + + # Send the WebSocket initiation packet + sock.send(data) + + # Check for the WebSocket response header + header = sock.readline()[:-2] + assert header == b'HTTP/1.1 101 Switching Protocols', header + + # We don't (currently) need these headers + # FIXME: should we check the return key? + while header: + header = sock.readline()[:-2] + + return WebsocketClient(sock)