From 0b9cb03bae1a1436eaa0966a870a4032d8677d19 Mon Sep 17 00:00:00 2001 From: Michael Chen Date: Sun, 16 Jan 2022 21:31:07 +0100 Subject: [PATCH] Implemented auto updating lua script Downloads latest script from server if outdated (10 seconds) Server sends encrypted token to client to keep session new and rejects ..old tokens This allows updating the script in this repository --- MinecraftDiscordBot/AesCipher.cs | 44 ++++++ MinecraftDiscordBot/ClientScript.lua | 137 ++++++++++++++++++ MinecraftDiscordBot/ConnectedComputer.cs | 29 +++- MinecraftDiscordBot/Message.cs | 6 +- .../MinecraftDiscordBot.csproj | 8 + MinecraftDiscordBot/Program.cs | 46 +++++- MinecraftDiscordBot/TimeoutTokenProvider.cs | 32 ++++ 7 files changed, 290 insertions(+), 12 deletions(-) create mode 100644 MinecraftDiscordBot/AesCipher.cs create mode 100644 MinecraftDiscordBot/ClientScript.lua create mode 100644 MinecraftDiscordBot/TimeoutTokenProvider.cs diff --git a/MinecraftDiscordBot/AesCipher.cs b/MinecraftDiscordBot/AesCipher.cs new file mode 100644 index 0000000..3216387 --- /dev/null +++ b/MinecraftDiscordBot/AesCipher.cs @@ -0,0 +1,44 @@ +using System.Security.Cryptography; + +namespace MinecraftDiscordBot; + +public class AesCipher : ICipher { + private readonly byte[] key; + private readonly byte[] iv; + + public AesCipher() { + using var aes = Aes.Create(); + aes.GenerateKey(); + aes.GenerateIV(); + key = aes.Key; + iv = aes.IV; + } + + public byte[] Encrypt(byte[] plain) { + using var aes = Aes.Create(); + aes.Key = key; + aes.IV = iv; + var transformer = aes.CreateEncryptor(); + using var ms = new MemoryStream(); + using (var cs = new CryptoStream(ms, transformer, CryptoStreamMode.Write)) + cs.Write(plain); + return ms.ToArray(); + } + + public byte[] Decrypt(byte[] cipher) { + using Aes aes = Aes.Create(); + aes.Key = key; + aes.IV = iv; + var transformer = aes.CreateDecryptor(); + using MemoryStream ms = new MemoryStream(cipher); + using CryptoStream cs = new CryptoStream(ms, transformer, CryptoStreamMode.Read); + using MemoryStream os = new MemoryStream(); + cs.CopyTo(os); + return os.ToArray(); + } +} + +public interface ICipher { + byte[] Decrypt(byte[] cipher); + byte[] Encrypt(byte[] plain); +} \ No newline at end of file diff --git a/MinecraftDiscordBot/ClientScript.lua b/MinecraftDiscordBot/ClientScript.lua new file mode 100644 index 0000000..04dea44 --- /dev/null +++ b/MinecraftDiscordBot/ClientScript.lua @@ -0,0 +1,137 @@ +local secretToken = "$TOKEN" +local connectionUri = "ws://ws.cnml.de:8081" +local waitSeconds = 5 + +local function chunkString(value, chunkSize) + if not chunkSize then chunkSize = 10000 end + local length = value:len() + local total = math.ceil(length / chunkSize) + local chunks = {} + local i = 1 + for i=1,total do + local pos = 1 + ((i - 1) * chunkSize) + chunks[i] = value:sub(pos, pos + chunkSize - 1) + end + return total, chunks +end + +local function sendJson(socket, message) + return socket.send(textutils.serializeJSON(message)) +end + +local function sendResponse(socket, id, result, success) + if success == nil then success = true end + + if not success then + sendJson(socket, { id = id, result = result, success = success }) + return + end + + local total, chunks = chunkString(result) + for i, chunk in pairs(chunks) do + sendJson(socket, { id = id, result = chunk, chunk = i, total = total, success = success }) + end +end + +-- error: no rs system +-- return rssystem rs +local function getPeripheral(name) + local dev = peripheral.find(name) + if not dev then error("No peripheral '"..name.."' attached to the computer!") end + return dev +end + +-- error: any error during execution +-- return string result +local function getResponse(parsed) + if parsed.method == "energyusage" then + return tostring(getPeripheral("rsBridge").getEnergyUsage()) + elseif parsed.method == "energystorage" then + return tostring(getPeripheral("rsBridge").getEnergyStorage()) + elseif parsed.method == "listitems" then + return textutils.serializeJSON(getPeripheral("rsBridge").listItems()) + elseif parsed.method == "listfluids" then + return textutils.serializeJSON(getPeripheral("rsBridge").listFluids()) + elseif parsed.method == "craft" then + return tostring(getPeripheral("rsBridge").craftItem(parsed.params)) + end + + error("No message handler for method: "..parsed.method.."!") +end + +local function logJSON(json, prefix) + if not prefix then prefix = "" end + for k,v in pairs(json) do + local key = prefix..k + if type(v) == "table" then + logJSON(v, key..".") + else + print(key, "=", textutils.serializeJSON(v)) + end + end +end + +-- return bool success +local function handleMessage(socket, message) + local parsed, reason = textutils.unserializeJSON(message) + if not parsed then + print("Received message:", message) + printError("Message could not be parsed:", reason) + return false + end + + pcall(function() print("Received JSON:") logJSON(parsed) end) + + if parsed.type == "request" then + local success, result = pcall(function() return getResponse(parsed) end) + sendResponse(socket, parsed.id, result, success) + return true + end + + printError("Invalid message type:", parsed.type) + return false +end + +local function socketClient() + print("Connecting to the socket server at "..connectionUri.."...") + local socket, reason = http.websocket(connectionUri) + if not socket then error("Socket server could not be reached: "..reason) end + print("Connection successful!") + + socket.send("login="..secretToken) + while true do + local message, binary = socket.receive() + if not not message and not binary then + if message == "outdated" then + printError("Current script is outdated! Please update from the host!") + return + end + handleMessage(socket, message) + end + end +end + +local function termWaiter() + os.pullEvent("terminate") +end + +local function services() + parallel.waitForAny(termWaiter, function() + parallel.waitForAll(socketClient) + end) +end + +local function main() + while true do + local status, error = pcall(services) + if status then break end + printError("An uncaught exception was raised:", error) + printError("Restarting in", waitSeconds, "seconds...") + sleep(waitSeconds) + end +end + +local oldPullEvent = os.pullEvent +os.pullEvent = os.pullEventRaw +pcall(main) +os.pullEvent = oldPullEvent \ No newline at end of file diff --git a/MinecraftDiscordBot/ConnectedComputer.cs b/MinecraftDiscordBot/ConnectedComputer.cs index c98a30f..0abe9c4 100644 --- a/MinecraftDiscordBot/ConnectedComputer.cs +++ b/MinecraftDiscordBot/ConnectedComputer.cs @@ -27,9 +27,9 @@ public class RefinedStorageService : CommandRouter { public override Task RootAnswer(SocketUserMessage message, CancellationToken ct) => Task.FromResult(ResponseType.AsString("The RS system is online!")); - private async Task Method(string methodName, Func parser, CancellationToken ct) { + private async Task Method(string methodName, Func parser, CancellationToken ct, Dictionary? parameters = null) { var waiter = _taskSource.GetWaiter(parser, ct); - await _taskSource.Send(new RequestMessage(waiter.ID, methodName)); + await _taskSource.Send(new RequestMessage(waiter.ID, methodName, parameters)); return await waiter.Task; } @@ -38,11 +38,16 @@ public class RefinedStorageService : CommandRouter { private const string CmdListItems = "listitems"; private const string CmdItemName = "itemname"; private const string CmdListFluids = "listfluids"; + private const string CmdCraftItem = "craft"; public async Task GetEnergyUsageAsync(CancellationToken ct) => await Method(CmdEnergyUsage, int.Parse, ct); public async Task GetEnergyStorageAsync(CancellationToken ct) => await Method(CmdEnergyStorage, int.Parse, ct); public async Task> ListItemsAsync(CancellationToken ct) => await Method(CmdListItems, ConnectedComputer.Deserialize>(), ct); public async Task> ListFluidsAsync(CancellationToken ct) => await Method(CmdListFluids, ConnectedComputer.Deserialize>(), ct); + public async Task CraftItem(string itemid, int amount, CancellationToken ct) => await Method(CmdCraftItem, ConnectedComputer.Deserialize(), ct, new() { + ["name"] = itemid, + ["count"] = amount + }); private Task> FilterItems(SocketUserMessage message, IEnumerable filters, CancellationToken ct) => FilterItems(message, filters.Select(ItemFilter.Parse), ct); @@ -72,6 +77,24 @@ public class RefinedStorageService : CommandRouter { [CommandHandler(CmdEnergyUsage, HelpText = "Get the amount of energy used by the RS system.")] public async Task HandleEnergyUsage(SocketUserMessage message, string[] parameters, CancellationToken ct) => ResponseType.AsString($"Refined Storage system currently uses {await GetEnergyUsageAsync(ct)} RF/t"); + [CommandHandler(CmdCraftItem, HelpText = "Craft a specific item given an item ID and optionally an amount.")] + public async Task HandleCraftItem(SocketUserMessage message, string[] parameters, CancellationToken ct) { + var amount = 1; + string itemid; + if (parameters.Length is 1 or 2) { + itemid = parameters[0]; + if (parameters.Length is 2) + if (int.TryParse(parameters[1], out var value)) amount = value; + else return ResponseType.AsString($"I expected an amount to craft, not '{parameters[1]}'!"); + } else return parameters.Length is < 1 + ? ResponseType.AsString("You have to give me at least an item name!") + : parameters.Length is > 2 + ? ResponseType.AsString("Yo, those are way too many arguments! I want only item name and maybe an amount!") + : throw new InvalidOperationException($"Forgot to match parameter length {parameters.Length}!"); + return await CraftItem(itemid, amount, ct) + ? ResponseType.AsString($"Alright, I'm starting to craft {amount} {itemid}.") + : ResponseType.AsString($"Nope, that somehow doesn't work!"); + } [CommandHandler(CmdItemName, HelpText = "Filter items by name.")] public async Task HandleItemName(SocketUserMessage message, string[] parameters, CancellationToken ct) { if (parameters.Length < 2) return ResponseType.AsString($"Usage: {CmdItemName} filters..."); @@ -165,7 +188,7 @@ public class ConnectedComputer : CommandRouter, ITaskWaitSource { } private readonly ICommandHandler _rs; - [CommandHandler("rs", HelpText ="Provides some commands for interacting with the Refined Storage system.")] + [CommandHandler("rs", HelpText = "Provides some commands for interacting with the Refined Storage system.")] public Task RefinedStorageHandler(SocketUserMessage message, string[] parameters, CancellationToken ct) => _rs.HandleCommand(message, parameters, ct); diff --git a/MinecraftDiscordBot/Message.cs b/MinecraftDiscordBot/Message.cs index d6ab349..2db0dd7 100644 --- a/MinecraftDiscordBot/Message.cs +++ b/MinecraftDiscordBot/Message.cs @@ -49,10 +49,10 @@ public class ReplyMessage : Message { } public class RequestMessage : Message { - public RequestMessage(int answerId, string method, Dictionary? parameters = null) { + public RequestMessage(int answerId, string method, Dictionary? parameters = null) { AnswerId = answerId; Method = method; - Parameters = (parameters ?? Enumerable.Empty>()) + Parameters = (parameters ?? Enumerable.Empty>()) .ToDictionary(i => i.Key, i => i.Value); } [JsonProperty("id")] @@ -60,6 +60,6 @@ public class RequestMessage : Message { [JsonProperty("method")] public string Method { get; set; } [JsonProperty("params")] - public Dictionary Parameters { get; } + public Dictionary Parameters { get; } public override string Type => "request"; } \ No newline at end of file diff --git a/MinecraftDiscordBot/MinecraftDiscordBot.csproj b/MinecraftDiscordBot/MinecraftDiscordBot.csproj index 810ea2d..63bff82 100644 --- a/MinecraftDiscordBot/MinecraftDiscordBot.csproj +++ b/MinecraftDiscordBot/MinecraftDiscordBot.csproj @@ -15,6 +15,10 @@ $(VersionPrefix) + + + + @@ -24,4 +28,8 @@ + + + + diff --git a/MinecraftDiscordBot/Program.cs b/MinecraftDiscordBot/Program.cs index c186ae5..b84fa77 100644 --- a/MinecraftDiscordBot/Program.cs +++ b/MinecraftDiscordBot/Program.cs @@ -30,6 +30,17 @@ public class Program : IDisposable, ICommandHandler { private ConnectedComputer? _rsSystem = null; private bool disposedValue; public static bool OnlineNotifications => false; + public static readonly string ClientScript = GetClientScript(); + private readonly ITokenProvider _tokenProvider = new TimeoutTokenProvider(10); + + private string GetVerifiedClientScript() => ClientScript.Replace("$TOKEN", _tokenProvider.GenerateToken()); + + private static string GetClientScript() { + using var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream("MinecraftDiscordBot.ClientScript.lua"); + if (stream is null) throw new FileNotFoundException("Client script could not be loaded!"); + using var sr = new StreamReader(stream); + return sr.ReadToEnd(); + } public ConnectedComputer? Computer { get => _rsSystem; set { @@ -120,8 +131,34 @@ public class Program : IDisposable, ICommandHandler { } } - private static async Task SocketReceived(IWebSocketConnection socket, string message) - => await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Received: {message}"); + private async Task SocketReceived(IWebSocketConnection socket, string message) { + await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Received: {message}"); + await (message switch { + "getcode" => SendClientCode(socket), + string s when s.StartsWith("login=") => ClientComputerConnected(socket, s[6..]), + _ => DisruptClientConnection(socket, "Protocol violation!") + }); + } + + private async Task ClientComputerConnected(IWebSocketConnection socket, string token) { + if (!_tokenProvider.VerifyToken(token)) { + await DisruptClientConnection(socket, "outdated"); + return; + } + await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client logged in with valid script!"); + AddComputerSocket(socket, new(socket)); + } + + private static async Task DisruptClientConnection(IWebSocketConnection socket, string reason) { + await socket.Send(reason); + await LogWarningAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client will be terminated, reason: {reason}"); + socket.Close(); + } + + private async Task SendClientCode(IWebSocketConnection socket) { + await socket.Send(GetVerifiedClientScript()); + await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Script sent to client!"); + } private void AddComputerSocket(IWebSocketConnection socket, ConnectedComputer pc) => Computer = pc; @@ -134,10 +171,7 @@ public class Program : IDisposable, ICommandHandler { await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client disconnected!"); } - private async Task SocketOpened(IWebSocketConnection socket) { - AddComputerSocket(socket, new(socket)); - await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client connected from {socket.ConnectionInfo.ClientIpAddress}:{socket.ConnectionInfo.ClientPort}!"); - } + private static async Task SocketOpened(IWebSocketConnection socket) => await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client connected from {socket.ConnectionInfo.ClientIpAddress}:{socket.ConnectionInfo.ClientPort}!"); private async Task DiscordMessageReceived(SocketMessage arg, int timeout = 10000) { if (arg is not SocketUserMessage message) return; diff --git a/MinecraftDiscordBot/TimeoutTokenProvider.cs b/MinecraftDiscordBot/TimeoutTokenProvider.cs new file mode 100644 index 0000000..03cb7aa --- /dev/null +++ b/MinecraftDiscordBot/TimeoutTokenProvider.cs @@ -0,0 +1,32 @@ +namespace MinecraftDiscordBot; + +public class TimeoutTokenProvider : ITokenProvider { + public TimeoutTokenProvider(int timeoutSeconds, ICipher? cipher = null) { + _timeout = timeoutSeconds; + _cipher = cipher ?? new AesCipher(); + } + private readonly ICipher _cipher; + private readonly int _timeout; + public bool VerifyToken(string token) { + byte[] data; + try { + data = _cipher.Decrypt(Convert.FromHexString(token)); + } catch (Exception e) { + Program.LogError("TokenProvider", e); + return false; + } + var when = DateTime.FromBinary(BitConverter.ToInt64(data, 0)); + return when >= DateTime.UtcNow.AddSeconds(-_timeout); + } + public string GenerateToken() { + var time = BitConverter.GetBytes(DateTime.UtcNow.ToBinary()); + var key = Guid.NewGuid().ToByteArray(); + var token = Convert.ToHexString(_cipher.Encrypt(time.Concat(key).ToArray())); + return token; + } +} + +public interface ITokenProvider { + string GenerateToken(); + bool VerifyToken(string token); +} \ No newline at end of file