Implemented auto updating lua script

Downloads latest script from server if outdated (10 seconds)
Server sends encrypted token to client to keep session new and rejects
..old tokens
This allows updating the script in this repository
This commit is contained in:
Michael Chen 2022-01-16 21:31:07 +01:00
parent 9406aaa050
commit 0b9cb03bae
No known key found for this signature in database
GPG Key ID: 1CBC7AA5671437BB
7 changed files with 290 additions and 12 deletions

View File

@ -0,0 +1,44 @@
using System.Security.Cryptography;
namespace MinecraftDiscordBot;
public class AesCipher : ICipher {
private readonly byte[] key;
private readonly byte[] iv;
public AesCipher() {
using var aes = Aes.Create();
aes.GenerateKey();
aes.GenerateIV();
key = aes.Key;
iv = aes.IV;
}
public byte[] Encrypt(byte[] plain) {
using var aes = Aes.Create();
aes.Key = key;
aes.IV = iv;
var transformer = aes.CreateEncryptor();
using var ms = new MemoryStream();
using (var cs = new CryptoStream(ms, transformer, CryptoStreamMode.Write))
cs.Write(plain);
return ms.ToArray();
}
public byte[] Decrypt(byte[] cipher) {
using Aes aes = Aes.Create();
aes.Key = key;
aes.IV = iv;
var transformer = aes.CreateDecryptor();
using MemoryStream ms = new MemoryStream(cipher);
using CryptoStream cs = new CryptoStream(ms, transformer, CryptoStreamMode.Read);
using MemoryStream os = new MemoryStream();
cs.CopyTo(os);
return os.ToArray();
}
}
public interface ICipher {
byte[] Decrypt(byte[] cipher);
byte[] Encrypt(byte[] plain);
}

View File

@ -0,0 +1,137 @@
local secretToken = "$TOKEN"
local connectionUri = "ws://ws.cnml.de:8081"
local waitSeconds = 5
local function chunkString(value, chunkSize)
if not chunkSize then chunkSize = 10000 end
local length = value:len()
local total = math.ceil(length / chunkSize)
local chunks = {}
local i = 1
for i=1,total do
local pos = 1 + ((i - 1) * chunkSize)
chunks[i] = value:sub(pos, pos + chunkSize - 1)
end
return total, chunks
end
local function sendJson(socket, message)
return socket.send(textutils.serializeJSON(message))
end
local function sendResponse(socket, id, result, success)
if success == nil then success = true end
if not success then
sendJson(socket, { id = id, result = result, success = success })
return
end
local total, chunks = chunkString(result)
for i, chunk in pairs(chunks) do
sendJson(socket, { id = id, result = chunk, chunk = i, total = total, success = success })
end
end
-- error: no rs system
-- return rssystem rs
local function getPeripheral(name)
local dev = peripheral.find(name)
if not dev then error("No peripheral '"..name.."' attached to the computer!") end
return dev
end
-- error: any error during execution
-- return string result
local function getResponse(parsed)
if parsed.method == "energyusage" then
return tostring(getPeripheral("rsBridge").getEnergyUsage())
elseif parsed.method == "energystorage" then
return tostring(getPeripheral("rsBridge").getEnergyStorage())
elseif parsed.method == "listitems" then
return textutils.serializeJSON(getPeripheral("rsBridge").listItems())
elseif parsed.method == "listfluids" then
return textutils.serializeJSON(getPeripheral("rsBridge").listFluids())
elseif parsed.method == "craft" then
return tostring(getPeripheral("rsBridge").craftItem(parsed.params))
end
error("No message handler for method: "..parsed.method.."!")
end
local function logJSON(json, prefix)
if not prefix then prefix = "" end
for k,v in pairs(json) do
local key = prefix..k
if type(v) == "table" then
logJSON(v, key..".")
else
print(key, "=", textutils.serializeJSON(v))
end
end
end
-- return bool success
local function handleMessage(socket, message)
local parsed, reason = textutils.unserializeJSON(message)
if not parsed then
print("Received message:", message)
printError("Message could not be parsed:", reason)
return false
end
pcall(function() print("Received JSON:") logJSON(parsed) end)
if parsed.type == "request" then
local success, result = pcall(function() return getResponse(parsed) end)
sendResponse(socket, parsed.id, result, success)
return true
end
printError("Invalid message type:", parsed.type)
return false
end
local function socketClient()
print("Connecting to the socket server at "..connectionUri.."...")
local socket, reason = http.websocket(connectionUri)
if not socket then error("Socket server could not be reached: "..reason) end
print("Connection successful!")
socket.send("login="..secretToken)
while true do
local message, binary = socket.receive()
if not not message and not binary then
if message == "outdated" then
printError("Current script is outdated! Please update from the host!")
return
end
handleMessage(socket, message)
end
end
end
local function termWaiter()
os.pullEvent("terminate")
end
local function services()
parallel.waitForAny(termWaiter, function()
parallel.waitForAll(socketClient)
end)
end
local function main()
while true do
local status, error = pcall(services)
if status then break end
printError("An uncaught exception was raised:", error)
printError("Restarting in", waitSeconds, "seconds...")
sleep(waitSeconds)
end
end
local oldPullEvent = os.pullEvent
os.pullEvent = os.pullEventRaw
pcall(main)
os.pullEvent = oldPullEvent

View File

@ -27,9 +27,9 @@ public class RefinedStorageService : CommandRouter {
public override Task<ResponseType> RootAnswer(SocketUserMessage message, CancellationToken ct)
=> Task.FromResult(ResponseType.AsString("The RS system is online!"));
private async Task<T> Method<T>(string methodName, Func<string, T> parser, CancellationToken ct) {
private async Task<T> Method<T>(string methodName, Func<string, T> parser, CancellationToken ct, Dictionary<string, object>? parameters = null) {
var waiter = _taskSource.GetWaiter(parser, ct);
await _taskSource.Send(new RequestMessage(waiter.ID, methodName));
await _taskSource.Send(new RequestMessage(waiter.ID, methodName, parameters));
return await waiter.Task;
}
@ -38,11 +38,16 @@ public class RefinedStorageService : CommandRouter {
private const string CmdListItems = "listitems";
private const string CmdItemName = "itemname";
private const string CmdListFluids = "listfluids";
private const string CmdCraftItem = "craft";
public async Task<int> GetEnergyUsageAsync(CancellationToken ct) => await Method(CmdEnergyUsage, int.Parse, ct);
public async Task<int> GetEnergyStorageAsync(CancellationToken ct) => await Method(CmdEnergyStorage, int.Parse, ct);
public async Task<IEnumerable<Item>> ListItemsAsync(CancellationToken ct) => await Method(CmdListItems, ConnectedComputer.Deserialize<IEnumerable<Item>>(), ct);
public async Task<IEnumerable<Fluid>> ListFluidsAsync(CancellationToken ct) => await Method(CmdListFluids, ConnectedComputer.Deserialize<IEnumerable<Fluid>>(), ct);
public async Task<bool> CraftItem(string itemid, int amount, CancellationToken ct) => await Method(CmdCraftItem, ConnectedComputer.Deserialize<bool>(), ct, new() {
["name"] = itemid,
["count"] = amount
});
private Task<IEnumerable<Item>> FilterItems(SocketUserMessage message, IEnumerable<string> filters, CancellationToken ct)
=> FilterItems(message, filters.Select(ItemFilter.Parse), ct);
@ -72,6 +77,24 @@ public class RefinedStorageService : CommandRouter {
[CommandHandler(CmdEnergyUsage, HelpText = "Get the amount of energy used by the RS system.")]
public async Task<ResponseType> HandleEnergyUsage(SocketUserMessage message, string[] parameters, CancellationToken ct)
=> ResponseType.AsString($"Refined Storage system currently uses {await GetEnergyUsageAsync(ct)} RF/t");
[CommandHandler(CmdCraftItem, HelpText = "Craft a specific item given an item ID and optionally an amount.")]
public async Task<ResponseType> HandleCraftItem(SocketUserMessage message, string[] parameters, CancellationToken ct) {
var amount = 1;
string itemid;
if (parameters.Length is 1 or 2) {
itemid = parameters[0];
if (parameters.Length is 2)
if (int.TryParse(parameters[1], out var value)) amount = value;
else return ResponseType.AsString($"I expected an amount to craft, not '{parameters[1]}'!");
} else return parameters.Length is < 1
? ResponseType.AsString("You have to give me at least an item name!")
: parameters.Length is > 2
? ResponseType.AsString("Yo, those are way too many arguments! I want only item name and maybe an amount!")
: throw new InvalidOperationException($"Forgot to match parameter length {parameters.Length}!");
return await CraftItem(itemid, amount, ct)
? ResponseType.AsString($"Alright, I'm starting to craft {amount} {itemid}.")
: ResponseType.AsString($"Nope, that somehow doesn't work!");
}
[CommandHandler(CmdItemName, HelpText = "Filter items by name.")]
public async Task<ResponseType> HandleItemName(SocketUserMessage message, string[] parameters, CancellationToken ct) {
if (parameters.Length < 2) return ResponseType.AsString($"Usage: {CmdItemName} filters...");

View File

@ -49,10 +49,10 @@ public class ReplyMessage : Message {
}
public class RequestMessage : Message {
public RequestMessage(int answerId, string method, Dictionary<string, string>? parameters = null) {
public RequestMessage(int answerId, string method, Dictionary<string, object>? parameters = null) {
AnswerId = answerId;
Method = method;
Parameters = (parameters ?? Enumerable.Empty<KeyValuePair<string, string>>())
Parameters = (parameters ?? Enumerable.Empty<KeyValuePair<string, object>>())
.ToDictionary(i => i.Key, i => i.Value);
}
[JsonProperty("id")]
@ -60,6 +60,6 @@ public class RequestMessage : Message {
[JsonProperty("method")]
public string Method { get; set; }
[JsonProperty("params")]
public Dictionary<string, string> Parameters { get; }
public Dictionary<string, object> Parameters { get; }
public override string Type => "request";
}

View File

@ -15,6 +15,10 @@
<FileVersion>$(VersionPrefix)</FileVersion>
</PropertyGroup>
<ItemGroup>
<EmbeddedResource Include="ClientScript.lua" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="CommandLineParser" Version="2.8.0" />
<PackageReference Include="Discord.Net" Version="3.1.0" />
@ -24,4 +28,8 @@
<PackageReference Include="OneOf" Version="3.0.205" />
</ItemGroup>
<ItemGroup>
<Resource Include="ClientScript.lua" />
</ItemGroup>
</Project>

View File

@ -30,6 +30,17 @@ public class Program : IDisposable, ICommandHandler<ResponseType> {
private ConnectedComputer? _rsSystem = null;
private bool disposedValue;
public static bool OnlineNotifications => false;
public static readonly string ClientScript = GetClientScript();
private readonly ITokenProvider _tokenProvider = new TimeoutTokenProvider(10);
private string GetVerifiedClientScript() => ClientScript.Replace("$TOKEN", _tokenProvider.GenerateToken());
private static string GetClientScript() {
using var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream("MinecraftDiscordBot.ClientScript.lua");
if (stream is null) throw new FileNotFoundException("Client script could not be loaded!");
using var sr = new StreamReader(stream);
return sr.ReadToEnd();
}
public ConnectedComputer? Computer {
get => _rsSystem; set {
@ -120,8 +131,34 @@ public class Program : IDisposable, ICommandHandler<ResponseType> {
}
}
private static async Task SocketReceived(IWebSocketConnection socket, string message)
=> await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Received: {message}");
private async Task SocketReceived(IWebSocketConnection socket, string message) {
await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Received: {message}");
await (message switch {
"getcode" => SendClientCode(socket),
string s when s.StartsWith("login=") => ClientComputerConnected(socket, s[6..]),
_ => DisruptClientConnection(socket, "Protocol violation!")
});
}
private async Task ClientComputerConnected(IWebSocketConnection socket, string token) {
if (!_tokenProvider.VerifyToken(token)) {
await DisruptClientConnection(socket, "outdated");
return;
}
await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client logged in with valid script!");
AddComputerSocket(socket, new(socket));
}
private static async Task DisruptClientConnection(IWebSocketConnection socket, string reason) {
await socket.Send(reason);
await LogWarningAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client will be terminated, reason: {reason}");
socket.Close();
}
private async Task SendClientCode(IWebSocketConnection socket) {
await socket.Send(GetVerifiedClientScript());
await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Script sent to client!");
}
private void AddComputerSocket(IWebSocketConnection socket, ConnectedComputer pc) => Computer = pc;
@ -134,10 +171,7 @@ public class Program : IDisposable, ICommandHandler<ResponseType> {
await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client disconnected!");
}
private async Task SocketOpened(IWebSocketConnection socket) {
AddComputerSocket(socket, new(socket));
await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client connected from {socket.ConnectionInfo.ClientIpAddress}:{socket.ConnectionInfo.ClientPort}!");
}
private static async Task SocketOpened(IWebSocketConnection socket) => await LogInfoAsync(WebSocketSource, $"[{socket.ConnectionInfo.Id}] Client connected from {socket.ConnectionInfo.ClientIpAddress}:{socket.ConnectionInfo.ClientPort}!");
private async Task DiscordMessageReceived(SocketMessage arg, int timeout = 10000) {
if (arg is not SocketUserMessage message) return;

View File

@ -0,0 +1,32 @@
namespace MinecraftDiscordBot;
public class TimeoutTokenProvider : ITokenProvider {
public TimeoutTokenProvider(int timeoutSeconds, ICipher? cipher = null) {
_timeout = timeoutSeconds;
_cipher = cipher ?? new AesCipher();
}
private readonly ICipher _cipher;
private readonly int _timeout;
public bool VerifyToken(string token) {
byte[] data;
try {
data = _cipher.Decrypt(Convert.FromHexString(token));
} catch (Exception e) {
Program.LogError("TokenProvider", e);
return false;
}
var when = DateTime.FromBinary(BitConverter.ToInt64(data, 0));
return when >= DateTime.UtcNow.AddSeconds(-_timeout);
}
public string GenerateToken() {
var time = BitConverter.GetBytes(DateTime.UtcNow.ToBinary());
var key = Guid.NewGuid().ToByteArray();
var token = Convert.ToHexString(_cipher.Encrypt(time.Concat(key).ToArray()));
return token;
}
}
public interface ITokenProvider {
string GenerateToken();
bool VerifyToken(string token);
}